{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 快速入门：GNMT v2的MindSpore实现\n",
    "\n",
    "## 1. 项目概述\n",
    "### 1.1 机器翻译\n",
    "\n",
    "### 1.2 GNMT v2模型\n",
    "GNMT v2 模型类似于Google 的 Neural Machine Translation System: Bridging the Gap between Human and Machine Translation中描述的模型，主要用于语料库翻译\n",
    "\n",
    "GNMTv2 模型主要由编码器、解码器和注意力机制组成，其中编码器和解码器使用共享的词嵌入向量。编码器：由四个长短期记忆（LSTM）层组成。第一个 LSTM 层是双向的，而其他三层是单向的。解码器：由四个单向 LSTM 层和一个全连接分类器组成。LSTM 的输出嵌入维度为 1024。注意机制：使用标准化的 Bahdanau 注意机制。首先，解码器的第一层输出作为注意力机制的输入。然后，注意力机制的计算结果连接到解码器LSTM的输入，作为后续LSTM层的输入\n",
    "这个模型由一个有 8 个编码器和 8 个解码器层的深度 LSTM 网络组成，使用残差连接以及从解码器网络到编码器的注意力模块连接，主要特点有：\n",
    "\n",
    "- 为了提高并行度，从而减少训练时间，该模型使用注意机制将解码器的底层连接到编码器的顶层；\n",
    "- 为了提高最终的翻译速度，在推理计算中采用了低精度的算法（限制数值范围）；\n",
    "- 为了改进对罕见单词的处理，作者将单词分为有限的公共子单词单元（称之为“单词块”），用于输入和输出，该方法在“字符”分隔模型的灵活性和“单词”分隔模型的效率之间提供了良好的平衡，自然地处理了罕见单词的翻译，最终提高了系统的整体准确性；\n",
    "- 在测试过程中，通过Beam search 算法添加了长度标准化过程和覆盖惩罚，鼓励生成输出语句，能够尽可能覆盖源语句中的所有单词；\n",
    "- 为了直接优化翻译任务的 BLEU 分数，作者还使用了强化学习来细化模型；\n",
    "\n",
    "在 WMT14 的英法和英德基准测试中，GNMT 达到了最先进水平的竞争性结果，与谷歌的基于短语的生产系统相比，通过对一组孤立的简单句子进行人工并排评估，它平均减少了60%的翻译错误；\n",
    "\n",
    "### 1.3 环境要求\n",
    "- 硬件\n",
    "    1. 带GPU显卡或华为Ascend处理器的服务器\n",
    "- 软件\n",
    "    1. python3.5+\n",
    "    1. `mindspore`\n",
    "    2. `numpy`\n",
    "    3. `sacrebleu==1.4.14`\n",
    "    4. `sacremoses==0.0.35`\n",
    "    5. `subword_nmt==0.3.7`\n",
    "\n",
    "## 2. 数据准备\n",
    "### 2.1 数据下载\n",
    "本项目使用WMT英语-德语数据集。使用一下脚本下载并转换数据到指定格式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing to data/wmt16_de_en. To change this, set the OUTPUT_DIR environment variable.\n",
      "Extracting all files...\n",
      "europarl-v7.de-en.de\n",
      "europarl-v7.de-en.en\n",
      "commoncrawl.cs-en.annotation\n",
      "commoncrawl.cs-en.cs\n",
      "commoncrawl.cs-en.en\n",
      "commoncrawl.de-en.annotation\n",
      "commoncrawl.de-en.de\n",
      "commoncrawl.de-en.en\n",
      "commoncrawl.es-en.annotation\n",
      "commoncrawl.es-en.en\n",
      "commoncrawl.es-en.es\n",
      "commoncrawl.fr-en.annotation\n",
      "commoncrawl.fr-en.en\n",
      "commoncrawl.fr-en.fr\n",
      "commoncrawl.ru-en.annotation\n",
      "commoncrawl.ru-en.en\n",
      "commoncrawl.ru-en.ru\n",
      "training-parallel-nc-v11/\n",
      "training-parallel-nc-v11/news-commentary-v11.ru-en.ru\n",
      "training-parallel-nc-v11/news-commentary-v11.cs-en.en\n",
      "training-parallel-nc-v11/news-commentary-v11.de-en.de\n",
      "training-parallel-nc-v11/news-commentary-v11.ru-en.en\n",
      "training-parallel-nc-v11/news-commentary-v11.cs-en.cs\n",
      "training-parallel-nc-v11/news-commentary-v11.de-en.en\n",
      "dev/\n",
      "dev/newstest2009-ref.fr.sgm\n",
      "dev/newstest2013.es\n",
      "dev/newstest2014-deen-src.de.sgm\n",
      "dev/newstest2015-ruen-src.ru.sgm\n",
      "dev/newstest2010-ref.de.sgm\n",
      "dev/newstest2012-src.fr.sgm\n",
      "dev/newstest2014-ruen-ref.ru.sgm\n",
      "dev/news-test2008.en\n",
      "dev/news-test2008.es\n",
      "dev/newstest2009-ref.hu.sgm\n",
      "dev/newstest2014-csen-ref.en.sgm\n",
      "dev/newsdiscussdev2015-enfr-src.en.sgm\n",
      "dev/newstest2010.cs\n",
      "dev/news-test2008-src.hu.sgm\n",
      "dev/.newsdev2014-ref.en.sgm.swp\n",
      "dev/newstest2011-ref.cs.sgm\n",
      "dev/newstest2011-ref.fr.sgm\n",
      "dev/newsdev2016-enro-ref.ro.sgm\n",
      "dev/newstest2011.cs\n",
      "dev/newstest2009.es\n",
      "dev/newstest2011.en\n",
      "dev/newsdev2015-enfi-src.en.sgm\n",
      "dev/newstest2013.cs\n",
      "dev/newstest2012-ref.es.sgm\n",
      "dev/newstest2014-csen-ref.cs.sgm\n",
      "dev/newsdev2014-src.hi.sgm\n",
      "dev/newstest2015-encs-src.en.sgm\n",
      "dev/newsdev2014-src.en.sgm\n",
      "dev/newsdev2015-enfi-ref.fi.sgm\n",
      "dev/newstest2011-ref.es.sgm\n",
      "dev/newstest2013-src.ru.sgm\n",
      "dev/newstest2012-src.de.sgm\n",
      "dev/newsdev2016-tren-ref.en.sgm\n",
      "dev/newstest2011-src.fr.sgm\n",
      "dev/newssyscomb2009-src.de.sgm\n",
      "dev/newstest2012-src.es.sgm\n",
      "dev/newstest2010-ref.cs.sgm\n",
      "dev/newstest2014-hien-ref.hi.sgm\n",
      "dev/newssyscomb2009.de\n",
      "dev/newstest2011-ref.en.sgm\n",
      "dev/news-test2008.cs\n",
      "dev/newstest2010.en\n",
      "dev/newssyscomb2009.fr\n",
      "dev/newstest2012-ref.en.sgm\n",
      "dev/news-test2008.de\n",
      "dev/newstest2011.de\n",
      "dev/newstest2012.es\n",
      "dev/newsdev2016-entr-ref.tr.sgm\n",
      "dev/newstest2011-ref.de.sgm\n",
      "dev/newsdev2014-ref.hi.sgm\n",
      "dev/newstest2013-src.de.sgm\n",
      "dev/newstest2012-ref.fr.sgm\n",
      "dev/newstest2009.de\n",
      "dev/newstest2012.en\n",
      "dev/news-test2008-ref.cs.sgm\n",
      "dev/newstest2013-ref.fr.sgm\n",
      "dev/newsdev2014.hi\n",
      "dev/newstest2011-src.cs.sgm\n",
      "dev/newssyscomb2009-src.fr.sgm\n",
      "dev/newstest2012.ru\n",
      "dev/newstest2010-ref.es.sgm\n",
      "dev/newstest2010-src.es.sgm\n",
      "dev/news-test2008.fr\n",
      "dev/newstest2009.en\n",
      "dev/newstest2014-ruen-src.ru.sgm\n",
      "dev/newssyscomb2009-ref.cs.sgm\n",
      "dev/newstest2010-src.fr.sgm\n",
      "dev/newssyscomb2009-src.en.sgm\n",
      "dev/newstest2015-enru-ref.ru.sgm\n",
      "dev/newstest2015-ende-ref.de.sgm\n",
      "dev/newstest2013-ref.ru.sgm\n",
      "dev/newssyscomb2009-src.it.sgm\n",
      "dev/newsdiscusstest2015-enfr-src.en.sgm\n",
      "dev/newstest2015-fien-ref.en.sgm\n",
      "dev/newstest2010-src.en.sgm\n",
      "dev/newstest2009.fr\n",
      "dev/newstest2015-ruen-ref.en.sgm\n",
      "dev/newstest2013-src.es.sgm\n",
      "dev/newstest2014-hien-ref.en.sgm\n",
      "dev/news-test2008-src.en.sgm\n",
      "dev/newstest2012-ref.cs.sgm\n",
      "dev/news-test2008-ref.es.sgm\n",
      "dev/news-test2008-ref.fr.sgm\n",
      "dev/newstest2014-ruen-ref.en.sgm\n",
      "dev/news-test2008-src.es.sgm\n",
      "dev/newstest2014-fren-src.en.sgm\n",
      "dev/newstest2012-ref.de.sgm\n",
      "dev/newstest2014-csen-src.cs.sgm\n",
      "dev/newstest2014-csen-src.en.sgm\n",
      "dev/newstest2011-src.de.sgm\n",
      "dev/newssyscomb2009-src.cs.sgm\n",
      "dev/newstest2015-enfi-ref.fi.sgm\n",
      "dev/newstest2009-src.it.sgm\n",
      "dev/newstest2010-src.de.sgm\n",
      "dev/newstest2009-ref.cs.sgm\n",
      "dev/newssyscomb2009-ref.es.sgm\n",
      "dev/newstest2014-deen-src.en.sgm\n",
      "dev/newsdiscusstest2015-fren-ref.en.sgm\n",
      "dev/newstest2012.fr\n",
      "dev/newsdiscusstest2015-enfr-ref.fr.sgm\n",
      "dev/newsdev2016-enro-src.en.sgm\n",
      "dev/newstest2009-src.es.sgm\n",
      "dev/newstest2013-src.fr.sgm\n",
      "dev/newstest2015-deen-src.de.sgm\n",
      "dev/newsdev2015-fien-src.fi.sgm\n",
      "dev/newsdiscusstest2015-fren-src.fr.sgm\n",
      "dev/newstest2014-ruen-src.en.sgm\n",
      "dev/newstest2012-src.en.sgm\n",
      "dev/newstest2013.fr\n",
      "dev/newstest2015-enru-src.en.sgm\n",
      "dev/newstest2009-ref.es.sgm\n",
      "dev/newstest2011.fr\n",
      "dev/newstest2009-ref.en.sgm\n",
      "dev/newstest2015-enfi-src.en.sgm\n",
      "dev/newstest2009-src.xx.sgm\n",
      "dev/newstest2015-encs-ref.cs.sgm\n",
      "dev/newstest2013.ru\n",
      "dev/newstest2009.cs\n",
      "dev/newsdev2014.en\n",
      "dev/newstest2014-fren-ref.fr.sgm\n",
      "dev/news-test2008-ref.en.sgm\n",
      "dev/newssyscomb2009.es\n",
      "dev/news-test2008-src.cs.sgm\n",
      "dev/newsdev2016-roen-src.ro.sgm\n",
      "dev/.newstest2013-ref.en.sgm.swp\n",
      "dev/newssyscomb2009-ref.hu.sgm\n",
      "dev/newstest2010.de\n",
      "dev/newstest2013-ref.cs.sgm\n",
      "dev/newstest2013-ref.de.sgm\n",
      "dev/newstest2009-src.cs.sgm\n",
      "dev/newssyscomb2009.en\n",
      "dev/newssyscomb2009-ref.it.sgm\n",
      "dev/newstest2009-ref.it.sgm\n",
      "dev/newstest2010-ref.fr.sgm\n",
      "dev/newstest2015-csen-src.cs.sgm\n",
      "dev/newsdev2016-entr-src.en.sgm\n",
      "dev/newstest2010.es\n",
      "dev/news-test2008-src.de.sgm\n",
      "dev/newstest2013.en\n",
      "dev/newsdev2016-roen-ref.en.sgm\n",
      "dev/newstest2009-src.de.sgm\n",
      "dev/newstest2010-ref.en.sgm\n",
      "dev/newstest2011-src.es.sgm\n",
      "dev/newssyscomb2009-ref.en.sgm\n",
      "dev/newstest2014-fren-ref.en.sgm\n",
      "dev/newstest2012.cs\n",
      "dev/newstest2009-src.hu.sgm\n",
      "dev/newstest2009-src.fr.sgm\n",
      "dev/newstest2015-ende-src.en.sgm\n",
      "dev/newstest2013-src.cs.sgm\n",
      "dev/newstest2014-hien-src.hi.sgm\n",
      "dev/news-test2008-ref.hu.sgm\n",
      "dev/newstest2015-csen-ref.en.sgm\n",
      "dev/newstest2013-ref.es.sgm\n",
      "dev/newstest2013-ref.en.sgm\n",
      "dev/newstest2010-src.cs.sgm\n",
      "dev/newstest2010.fr\n",
      "dev/newstest2015-deen-ref.en.sgm\n",
      "dev/newstest2011.es\n",
      "dev/newsdev2016-tren-src.tr.sgm\n",
      "dev/newstest2013.de\n",
      "dev/newstest2014-fren-src.fr.sgm\n",
      "dev/newsdiscussdev2015-fren-ref.en.sgm\n",
      "dev/newsdiscussdev2015-fren-src.fr.sgm\n",
      "dev/newstest2014-deen-ref.de.sgm\n",
      "dev/newstest2013-src.en.sgm\n",
      "dev/newssyscomb2009-ref.fr.sgm\n",
      "dev/newssyscomb2009-ref.de.sgm\n",
      "dev/newstest2009-src.en.sgm\n",
      "dev/newstest2009-ref.de.sgm\n",
      "dev/newsdiscussdev2015-enfr-ref.fr.sgm\n",
      "dev/newssyscomb2009.cs\n",
      "dev/newstest2012-ref.ru.sgm\n",
      "dev/newstest2014-hien-src.en.sgm\n",
      "dev/news-test2008-src.fr.sgm\n",
      "dev/newsdev2015-fien-ref.en.sgm\n",
      "dev/newsdev2014-ref.en.sgm\n",
      "dev/newstest2015-fien-src.fi.sgm\n",
      "dev/news-test2008-ref.de.sgm\n",
      "dev/newstest2012-src.ru.sgm\n",
      "dev/newssyscomb2009-src.es.sgm\n",
      "dev/newssyscomb2009-src.hu.sgm\n",
      "dev/newstest2014-deen-ref.en.sgm\n",
      "dev/newstest2012.de\n",
      "dev/newstest2011-src.en.sgm\n",
      "dev/newstest2012-src.cs.sgm\n",
      "test/newstest2016-csen-ref.en.sgm\n",
      "test/newstest2016-csen-src.cs.sgm\n",
      "test/newstest2016-deen-ref.en.sgm\n",
      "test/newstest2016-deen-src.de.sgm\n",
      "test/newstest2016-encs-ref.cs.sgm\n",
      "test/newstest2016-encs-src.en.sgm\n",
      "test/newstest2016-ende-ref.de.sgm\n",
      "test/newstest2016-ende-src.en.sgm\n",
      "test/newstest2016-enfi-ref.fi.sgm\n",
      "test/newstest2016-enfi-src.en.sgm\n",
      "test/newstest2016-enro-ref.ro.sgm\n",
      "test/newstest2016-enro-src.en.sgm\n",
      "test/newstest2016-enru-ref.ru.sgm\n",
      "test/newstest2016-enru-src.en.sgm\n",
      "test/newstest2016-entr-ref.tr.sgm\n",
      "test/newstest2016-entr-src.en.sgm\n",
      "test/newstest2016-fien-ref.en.sgm\n",
      "test/newstest2016-fien-src.fi.sgm\n",
      "test/newstest2016-roen-ref.en.sgm\n",
      "test/newstest2016-roen-src.ro.sgm\n",
      "test/newstest2016-ruen-ref.en.sgm\n",
      "test/newstest2016-ruen-src.ru.sgm\n",
      "test/newstest2016-tren-ref.en.sgm\n",
      "test/newstest2016-tren-src.tr.sgm\n",
      "test/newstestB2016-enfi-ref.fi.sgm\n",
      "test/newstestB2016-enfi-src.en.sgm\n",
      "Extracting files done!\n",
      "4562102 data/wmt16_de_en/train.en\n",
      "4562102 data/wmt16_de_en/train.de\n",
      "Cloning moses for data processing\n",
      "Cloning into 'data/wmt16_de_en/mosesdecoder'...\n",
      "remote: Enumerating objects: 148097, done.\u001b[K\n",
      "remote: Counting objects: 100% (525/525), done.\u001b[K\n",
      "remote: Compressing objects: 100% (229/229), done.\u001b[K\n",
      "remote: Total 148097 (delta 323), reused 441 (delta 292), pack-reused 147572\u001b[K\n",
      "Receiving objects: 100% (148097/148097), 129.88 MiB | 8.54 MiB/s, done.\n",
      "Resolving deltas: 100% (114349/114349), done.\n",
      "HEAD is now at 8c5eaa1a1 Merge branch 'RELEASE-4.0' of github.com:jowagner/mosesdecoder\n",
      "/workspace/gnmt_v2/doc\n",
      "Tokenizing data/wmt16_de_en/newstest2009.de...\n",
      "Tokenizing data/wmt16_de_en/newstest2010.de...\n",
      "Tokenizing data/wmt16_de_en/newstest2011.de...\n",
      "Tokenizing data/wmt16_de_en/newstest2012.de...\n",
      "Tokenizing data/wmt16_de_en/newstest2013.de...\n",
      "Tokenizing data/wmt16_de_en/newstest2014.de...\n",
      "Tokenizing data/wmt16_de_en/newstest2015.de...\n",
      "Tokenizing data/wmt16_de_en/newstest2016.de...\n",
      "Tokenizing data/wmt16_de_en/train.de...\n",
      "^C\n"
     ]
    }
   ],
   "source": [
    "!bash ./scripts/wmt16_en_de.sh"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 MindRecord数据转换\n",
    "#### 2.2.1 Tokenizer\n",
    "定义`Tokenizer`分词器，这里`Tokenizer`分词器首先构建词表(Vocab list)，得到两个token和id，id和token相互对应的字典，然后利用`sacremoses`包分别实现`tokenize()`方法和`detokenize()`方法，也就是将英语和德语相互转换。\n",
    "\n",
    "对于`tokenize()`方法，有个很重要的点就是BPE(Byte Pair Encoding)算法。\n",
    "\n",
    "\n",
    "BPE首次在论文Neural Machine Translation of Rare Words with Subword Units(Sennrich et al., 2015)中被提出。BPE首先需要依赖一个可以预先将训练数据切分成单词的tokenizer，它们可以一些简单的基于空格的tokenizer，如GPT-2，Roberta等；也可以是一些更加复杂的、增加了一些规则的tokenizer，如XLM、FlauBERT。\n",
    "\n",
    "在使用了这些tokenizer后，我们可以得到一个在训练数据中出现过的单词的集合以及它们对应的频数。下一步，BPE使用这个集合中的所有符号（将单词拆分为字母）创建一个基本词表，然后学习合并规则以将基本词表的两个符号形成一个新符号，从而实现对基本词表的更新。它将持续这一操作，直到词表的大小达到了预置的规模。值得注意的是，这个预置的词表大小是一个超参数，需要提前指定。\n",
    "\n",
    "举个例子，假设经过预先切分后，单词及对应的频数如下：\n",
    "\n",
    "`(\"hug\", 10), (\"pug\", 5), (\"pun\", 12), (\"bun\", 4), (\"hugs\", 5)`\n",
    "\n",
    "因此，基本词表的内容为[“b”, “g”, “h”, “n”, “p”, “s”, “u”]。对应的，将所有的单词按照基本词表中的字母拆分，得到：\n",
    "\n",
    "`(\"h\" \"u\" \"g\", 10), (\"p\" \"u\" \"g\", 5), (\"p\" \"u\" \"n\", 12), (\"b\" \"u\" \"n\", 4), (\"h\" \"u\" \"g\" \"s\", 5)`\n",
    "\n",
    "接下来，BPE计算任意两个字母（符号）拼接到一起时，出现在语料中的频数，然后选择频数最大的字母（符号）对。接上例，\"hu\"组合的频数为15（\"hug\"出现了10次，“hugs\"中出现了5次）。在上面的例子中，频数最高的符号对是\"ug”，一共有20次。因此，tokenizer学习到的第一个合并规则就是将所有的\"ug\"合并到一起。于是，基本词表变为：\n",
    "\n",
    "`(\"h\" \"ug\", 10), (\"p\" \"ug\", 5), (\"p\" \"u\" \"n\", 12), (\"b\" \"u\" \"n\", 4), (\"h\" \"ug\" \"s\", 5)`\n",
    "\n",
    "应用相同的算法，下一个频数最高的组合是\"un\"，出现了16次，于是\"un\"被添加到词表中；接下来是\"hug\"，即\"h\"与第一步得到的\"ug\"组合的频数最高，共有15次，于是\"hug\"被添加到了词表中。\n",
    "\n",
    "此时，词表的内容为[“b”, “g”, “h”, “n”, “p”, “s”, “u”, “ug”, “un”, “hug”]，原始的单词按照词表拆分后的内容如下：\n",
    "`(\"hug\", 10), (\"p\" \"ug\", 5), (\"p\" \"un\", 12), (\"b\" \"un\", 4), (\"hug\" \"s\", 5)`\n",
    "\n",
    "假定BPE的训练到这一步就停止，接下来就是利用它学习到的这些规则来切分新的单词（只要新单词中没有超出基本词表之外的符号）。例如，单词\"bug\"将会被切分为[“b”, “ug”]，但是单词\"mug\"将会被切分为[\"\", “ug”]——这是因为\"m\"不在词表中。\n",
    "\n",
    "正如之前提到的，词表的规模——也就是基本词表的大小加上合并后的单词的数量——是一个超参数。例如，对于GPT而言，其词表的大小为40,478，其中，基本字符一共有478个，合并后的词有40,000个。.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from collections import defaultdict\n",
    "from functools import partial\n",
    "import subword_nmt.apply_bpe\n",
    "import sacremoses\n",
    "\n",
    "class Tokenizer:\n",
    "    \"\"\"\n",
    "    Tokenizer class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_address=None, bpe_code_address=None,\n",
    "                 src_en='en', tgt_de='de', vocab_pad=8, isolator='@@'):\n",
    "        \"\"\"\n",
    "        构造Tokenizer class.\n",
    "\n",
    "        Args:\n",
    "            vocab_address: vocabulary address.\n",
    "            bpe_code_address: path to the file with bpe codes.\n",
    "            vocab_pad: pads vocabulary to a multiple of 'vocab_pad' tokens.\n",
    "            isolator: tokenization isolator.\n",
    "        \"\"\"\n",
    "        self.padding_index = 0\n",
    "        self.unk_index = 1\n",
    "        self.bos_index = 2\n",
    "        self.eos_index = 3\n",
    "        self.pad_word = '<pad>'\n",
    "        self.unk_word = '<unk>'\n",
    "        self.bos_word = '<s>'\n",
    "        self.eos_word = r'<\\s>'\n",
    "        self.isolator = isolator\n",
    "        self.idx2tok = {}\n",
    "        \n",
    "        self.init_bpe(bpe_code_address)\n",
    "        \n",
    "        self.vocab_establist(vocab_address, vocab_pad)\n",
    "        self.sacremoses_tokenizer = sacremoses.MosesTokenizer(src_en)\n",
    "        self.sacremoses_detokenizer = sacremoses.MosesDetokenizer(tgt_de)\n",
    "\n",
    "    def init_bpe(self, bpe_code_address):\n",
    "        \"\"\"初始化 bpe.\"\"\"\n",
    "        if (bpe_code_address is not None) and os.path.exists(bpe_code_address):\n",
    "            with open(bpe_code_address, 'r') as f1:\n",
    "                self.bpe = subword_nmt.apply_bpe.BPE(f1)\n",
    "\n",
    "    def vocab_establist(self, vocab_address, vocab_pad):\n",
    "        \"\"\"构建 vocabulary.\"\"\"\n",
    "        if (vocab_address is None) or (not os.path.exists(vocab_address)):\n",
    "            return\n",
    "        vocab_words = [self.pad_word, self.unk_word, self.bos_word, self.eos_word]\n",
    "        with open(vocab_address) as f1:\n",
    "            for sentence in f1:\n",
    "                vocab_words.append(sentence.strip())\n",
    "        vocab_size = len(vocab_words)\n",
    "        padded_vocab_size = (vocab_size + vocab_pad - 1) // vocab_pad * vocab_pad\n",
    "        for idx in range(0, padded_vocab_size - vocab_size):\n",
    "            fil_token = f'filled{idx:04d}'\n",
    "            vocab_words.append(fil_token)\n",
    "        self.vocab_size = len(vocab_words)\n",
    "        self.tok2idx = defaultdict(partial(int, self.unk_index))\n",
    "        for idx, token in enumerate(vocab_words):\n",
    "            self.tok2idx[token] = idx\n",
    "        \n",
    "        self.idx2tok = defaultdict(partial(str, \",\"))\n",
    "        for token, idx in self.tok2idx.items():\n",
    "            self.idx2tok[idx] = token\n",
    "        \n",
    "\n",
    "    def tokenize(self, sentence):\n",
    "        \"\"\"对句子分词\"\"\"\n",
    "        tokenized = self.sacremoses_tokenizer.tokenize(sentence, return_str=True)\n",
    "        bpe = self.bpe.process_line(tokenized)\n",
    "        sentence = bpe.strip().split()\n",
    "        inputs = [self.tok2idx[i] for i in sentence]\n",
    "        inputs = [self.bos_index] + inputs + [self.eos_index]\n",
    "        return inputs\n",
    "\n",
    "    def detokenize(self, indexes, gap=' '):\n",
    "        \"\"\"Detokenizes single sentence and removes token isolator characters.\"\"\"\n",
    "        reconstruction_bpe = gap.join([self.idx2tok[idx] for idx in indexes])\n",
    "        reconstruction_bpe = reconstruction_bpe.replace(self.isolator + ' ', '')\n",
    "        reconstruction_bpe = reconstruction_bpe.replace(self.isolator, '')\n",
    "        reconstruction_bpe = reconstruction_bpe.replace(self.bos_word, '')\n",
    "        reconstruction_bpe = reconstruction_bpe.replace(self.eos_word, '')\n",
    "        reconstruction_bpe = reconstruction_bpe.replace(self.unk_word, '')\n",
    "        reconstruction_bpe = reconstruction_bpe.replace(self.pad_word, '')\n",
    "        reconstruction_bpe = reconstruction_bpe.strip()\n",
    "        reconstruction_words = self.sacremoses_detokenizer.detokenize(reconstruction_bpe.split())\n",
    "        return reconstruction_words"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2.2.2 DataLoader\n",
    "定义`DataLoader`类，该类有`padding()`方法，当句子长度不够时，可以为句子添加`<pad>`特殊填充符号。\n",
    "\n",
    "实现了`write_to_mindrecord()`方法，该方法可以将给定指定`Schema`类型的数据写入到`MindRecord`格式的数据集中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from mindspore.mindrecord import FileWriter\n",
    "\n",
    "SCHEMA = {\n",
    "    \"src\": {\"type\": \"int64\", \"shape\": [-1]},\n",
    "    \"src_padding\": {\"type\": \"int64\", \"shape\": [-1]},\n",
    "    \"prev_opt\": {\"type\": \"int64\", \"shape\": [-1]},\n",
    "    \"target\": {\"type\": \"int64\", \"shape\": [-1]},\n",
    "    \"tgt_padding\": {\"type\": \"int64\", \"shape\": [-1]},\n",
    "}\n",
    "\n",
    "TEST_SCHEMA = {\n",
    "    \"src\": {\"type\": \"int64\", \"shape\": [-1]},\n",
    "    \"src_padding\": {\"type\": \"int64\", \"shape\": [-1]},\n",
    "}\n",
    "\n",
    "class DataLoader:\n",
    "    \"\"\"定义数据集的数据加载器。\"\"\"\n",
    "    _SCHEMA = SCHEMA\n",
    "    _TEST_SCHEMA = TEST_SCHEMA\n",
    "\n",
    "    def __init__(self):\n",
    "        self._examples = []\n",
    "\n",
    "    def _load(self):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def padding(self, sen, padding_idx, need_sentence_len=None, dtype=np.int64):\n",
    "        \"\"\"为句子填充<pad>\"\"\"\n",
    "        if need_sentence_len is None:\n",
    "            return None\n",
    "        if sen.shape[0] > need_sentence_len:\n",
    "            return None\n",
    "        new_sen = np.array([padding_idx] * need_sentence_len, dtype=dtype)\n",
    "        new_sen[:sen.shape[0]] = sen[:]\n",
    "        return new_sen\n",
    "\n",
    "    def write_to_mindrecord(self, path, train_mode, shard_num=1, desc=\"gnmt\"):\n",
    "        \"\"\"\n",
    "        将内置examples数据写入到 mindrecord 文件\n",
    "\n",
    "        Args:\n",
    "            path (str): File path.\n",
    "            shard_num (int): Shard num.\n",
    "            desc (str): Description.\n",
    "        \"\"\"\n",
    "        if not os.path.isabs(path):\n",
    "            path = os.path.abspath(path)\n",
    "\n",
    "        writer = FileWriter(file_name=path, shard_num=shard_num)\n",
    "        if train_mode:\n",
    "            writer.add_schema(self._SCHEMA, desc)\n",
    "        else:\n",
    "            writer.add_schema(self._TEST_SCHEMA, desc)\n",
    "        if not self._examples:\n",
    "            self._load()\n",
    "\n",
    "        writer.write_raw_data(self._examples)\n",
    "        writer.commit()\n",
    "        print(f\"| Wrote to {path}.\")\n",
    "\n",
    "    def _add_example(self, example):\n",
    "        self._examples.append(example)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2.2.3 TextDataLoader\n",
    "定义`TextDataLoader`类用于将单语种的测试语料转换为`MindRecord`文件，这里需要对数据语料简单处理。\n",
    "\n",
    "首先将测试数据集中文本语料通过`Tokenizer`分词器将文本数据转换成词向量，同时为每一句词向量添加开始符(bos_index)和结束符(eod_index)，去掉长度太长的句子得到输入和输出双语向量，完成数据加载。\n",
    "\n",
    "最后还需要为`MindRecord`中写入的类型格式，制作格式`Schema`的Json文件，方便后续直接根据指定的数据类型读取`MindRecord`数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class TextDataLoader(DataLoader):\n",
    "    \"\"\"定义文本数据的加载器loader\"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 src_filepath: str,\n",
    "                 tokenizer: Tokenizer,\n",
    "                 min_sen_len=0,\n",
    "                 source_max_sen_len=None,\n",
    "                 schema_address=None):\n",
    "        super(TextDataLoader, self).__init__()\n",
    "        self._src_filepath = src_filepath\n",
    "        self.tokenizer = tokenizer\n",
    "        self.min_sen_len = min_sen_len\n",
    "        self.source_max_sen_len = source_max_sen_len\n",
    "        self.schema_address = schema_address\n",
    "\n",
    "    def _load(self):\n",
    "        count = 0\n",
    "        if self.source_max_sen_len is None:\n",
    "            with open(self._src_filepath, \"r\") as _src_file:\n",
    "                print(f\" | count the max_sen_len of corpus {self._src_filepath}.\")\n",
    "                max_src = 0\n",
    "                for _, _pair in enumerate(_src_file):\n",
    "                    src_tokens = self.tokenizer.tokenize(_pair)\n",
    "                    src_len = len(src_tokens)\n",
    "                    if src_len > max_src:\n",
    "                        max_src = src_len\n",
    "                self.source_max_sen_len = max_src\n",
    "\n",
    "        with open(self._src_filepath, \"r\") as _src_file:\n",
    "            print(f\" | Processing corpus {self._src_filepath}.\")\n",
    "            for _, _pair in enumerate(_src_file):\n",
    "                src_tokens = self.tokenizer.tokenize(_pair)\n",
    "                src_len = len(src_tokens)\n",
    "                src_tokens = np.array(src_tokens)\n",
    "                # encoder的输入\n",
    "                encoder_input = self.padding(src_tokens, self.tokenizer.padding_index, self.source_max_sen_len)\n",
    "                src_padding = np.zeros(shape=self.source_max_sen_len, dtype=np.int64)\n",
    "                for i in range(src_len):\n",
    "                    src_padding[i] = 1\n",
    "\n",
    "                example = {\n",
    "                    \"src\": encoder_input,\n",
    "                    \"src_padding\": src_padding\n",
    "                }\n",
    "                self._add_example(example)\n",
    "                count += 1\n",
    "\n",
    "            print(f\" | source padding_len = {self.source_max_sen_len}.\")\n",
    "            print(f\" | Total  activate  sen = {count}.\")\n",
    "            print(f\" | Total  sen = {count}.\")\n",
    "\n",
    "            #编写Schema文件\n",
    "            if self.schema_address is not None:\n",
    "                provlist = [count, self.source_max_sen_len, self.source_max_sen_len]\n",
    "                columns = [\"src\", \"src_padding\"]\n",
    "                with open(self.schema_address, \"w\", encoding=\"utf-8\") as  f:\n",
    "                    f.write(\"{\\n\")\n",
    "                    f.write('  \"datasetType\":\"MS\",\\n')\n",
    "                    f.write('  \"numRows\":%s,\\n' % provlist[0])\n",
    "                    f.write('  \"columns\":{\\n')\n",
    "                    t = 1\n",
    "                    for name in columns:\n",
    "                        f.write('    \"%s\":{\\n' % name)\n",
    "                        f.write('      \"type\":\"int64\",\\n')\n",
    "                        f.write('      \"rank\":1,\\n')\n",
    "                        f.write('      \"shape\":[%s]\\n' % provlist[t])\n",
    "                        f.write('    }')\n",
    "                        if t < len(columns):\n",
    "                            f.write(',')\n",
    "                        f.write('\\n')\n",
    "                        t += 1\n",
    "                    f.write('  }\\n}\\n')\n",
    "                    print(\" | Write to \" + self.schema_address)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2.2.4 BiLingualDataLoader\n",
    "与上述一致，我们这里定义`BiLingualDataLoader`双语数据加载器，将双语的训练数据语料简单处理转换成为`MindRecord`数据格式。\n",
    "\n",
    "首先将双语语料通过`Tokenizer`分词器将器转换成向量，同时为每一句向量添加开始符(bos_index)和结束符(eod_index)，同时去掉长度太长的句子得到输入和输出双语向量，完成数据加载。\n",
    "\n",
    "唯一与上述定义的`TextDataLoader`不一致的地方在于训练数据是双语翻译语料，因此，需要有`src`的源语种和`tgt`的目标语种的区别。\n",
    "\n",
    "同时，由于是GNMT同样也是seq2seq结构，训练阶段需要输入`encoder_input`、`decoder_input`、`decoder_output`。我们这里以`prev_opt`代替`decoder_input`。这里的`decoder_output`就是`prev_opt`的shift一位（如下图）\n",
    "\n",
    "<center>\n",
    "    <img style=\"border-radius: 0.3125em;\n",
    "    box-shadow: 0 2px 4px 0 rgba(34,36,38,.12),0 2px 10px 0 rgba(34,36,38,.08); width:50%; height:50%;\" \n",
    "    src=\"https://img-blog.csdn.net/20171201092713541\">\n",
    "    <br>\n",
    "    <div style=\"color:orange; border-bottom: 1px solid #d9d9d9;\n",
    "    display: inline-block;\n",
    "    color: #999;\n",
    "    padding: 2px;\">图1. decoder_output与prev_opt对应关系</div>\n",
    "</center>\n",
    "\n",
    "\n",
    "同样的最后为`MindRecord`中写入的数据，制作格式`Schema`的Json文件，方便后续直接根据指定的数据类型读取`MindRecord`数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BiLingualDataLoader(DataLoader):\n",
    "    \"\"\"Loader for bilingual data.\"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 src_filepath: str,\n",
    "                 tgt_filepath: str,\n",
    "                 tokenizer: Tokenizer,\n",
    "                 min_sen_len=0,\n",
    "                 source_max_sen_len=None,\n",
    "                 target_max_sen_len=80,\n",
    "                 schema_address=None):\n",
    "        super(BiLingualDataLoader, self).__init__()\n",
    "        self._src_filepath = src_filepath\n",
    "        self._tgt_filepath = tgt_filepath\n",
    "        self.tokenizer = tokenizer\n",
    "        self.min_sen_len = min_sen_len\n",
    "        self.source_max_sen_len = source_max_sen_len\n",
    "        self.target_max_sen_len = target_max_sen_len\n",
    "        self.schema_address = schema_address\n",
    "\n",
    "    def _load(self):\n",
    "        count = 0\n",
    "        if self.source_max_sen_len is None:\n",
    "            with open(self._src_filepath, \"r\") as _src_file:\n",
    "                print(f\" | count the max_sen_len of corpus {self._src_filepath}.\")\n",
    "                max_src = 0\n",
    "                for _, _pair in enumerate(_src_file):\n",
    "                    src_tokens = [\n",
    "                        int(self.tokenizer.tok2idx[t])\n",
    "                        for t in _pair.strip().split(\" \") if t\n",
    "                    ]\n",
    "                    src_len = len(src_tokens)\n",
    "                    if src_len > max_src:\n",
    "                        max_src = src_len\n",
    "                self.source_max_sen_len = max_src + 2\n",
    "\n",
    "        if self.target_max_sen_len is None:\n",
    "            with open(self._src_filepath, \"r\") as _tgt_file:\n",
    "                print(f\" | count the max_sen_len of corpus {self._src_filepath}.\")\n",
    "                max_tgt = 0\n",
    "                for _, _pair in enumerate(_tgt_file):\n",
    "                    src_tokens = [\n",
    "                        int(self.tokenizer.tok2idx[t])\n",
    "                        for t in _pair.strip().split(\" \") if t\n",
    "                    ]\n",
    "                    tgt_len = len(src_tokens)\n",
    "                    if tgt_len > max_tgt:\n",
    "                        max_tgt = tgt_len\n",
    "                self.target_max_sen_len = max_tgt + 1\n",
    "\n",
    "        with open(self._src_filepath, \"r\") as _src_file:\n",
    "            print(f\" | Processing corpus {self._src_filepath}.\")\n",
    "            print(f\" | Processing corpus {self._tgt_filepath}.\")\n",
    "            with open(self._tgt_filepath, \"r\") as _tgt_file:\n",
    "                for _, _pair in enumerate(zip(_src_file, _tgt_file)):\n",
    "\n",
    "                    src_tokens = [\n",
    "                        int(self.tokenizer.tok2idx[t])\n",
    "                        for t in _pair[0].strip().split(\" \") if t\n",
    "                    ]\n",
    "                    tgt_tokens = [\n",
    "                        int(self.tokenizer.tok2idx[t])\n",
    "                        for t in _pair[1].strip().split(\" \") if t\n",
    "                    ]\n",
    "                    src_tokens.insert(0, self.tokenizer.bos_index)\n",
    "                    src_tokens.append(self.tokenizer.eos_index)\n",
    "                    tgt_tokens.insert(0, self.tokenizer.bos_index)\n",
    "                    tgt_tokens.append(self.tokenizer.eos_index)\n",
    "                    src_tokens = np.array(src_tokens)\n",
    "                    tgt_tokens = np.array(tgt_tokens)\n",
    "                    src_len = src_tokens.shape[0]\n",
    "                    tgt_len = tgt_tokens.shape[0]\n",
    "\n",
    "                    if (src_len > self.source_max_sen_len) or (src_len < self.min_sen_len) or (\n",
    "                            tgt_len > (self.target_max_sen_len + 1)) or (tgt_len < self.min_sen_len):\n",
    "                        print(f\"+++++ delete! src_len={src_len}, tgt_len={tgt_len - 1}, \"\n",
    "                              f\"source_max_sen_len={self.source_max_sen_len},\"\n",
    "                              f\"target_max_sen_len={self.target_max_sen_len}\")\n",
    "                        continue\n",
    "                    # encoder inputs\n",
    "                    encoder_input = self.padding(src_tokens, self.tokenizer.padding_index, self.source_max_sen_len)\n",
    "                    src_padding = np.zeros(shape=self.source_max_sen_len, dtype=np.int64)\n",
    "                    for i in range(src_len):\n",
    "                        src_padding[i] = 1\n",
    "                    # decoder inputs\n",
    "                    decoder_input = self.padding(tgt_tokens[:-1], self.tokenizer.padding_index, self.target_max_sen_len)\n",
    "                    # decoder outputs\n",
    "                    decoder_output = self.padding(tgt_tokens[1:], self.tokenizer.padding_index, self.target_max_sen_len)\n",
    "                    tgt_padding = np.zeros(shape=self.target_max_sen_len + 1, dtype=np.int64)\n",
    "                    for j in range(tgt_len):\n",
    "                        tgt_padding[j] = 1\n",
    "                    tgt_padding = tgt_padding[1:]\n",
    "                    decoder_input = np.array(decoder_input, dtype=np.int64)\n",
    "                    decoder_output = np.array(decoder_output, dtype=np.int64)\n",
    "                    tgt_padding = np.array(tgt_padding, dtype=np.int64)\n",
    "\n",
    "                    example = {\n",
    "                        \"src\": encoder_input,\n",
    "                        \"src_padding\": src_padding,\n",
    "                        \"prev_opt\": decoder_input,\n",
    "                        \"target\": decoder_output,\n",
    "                        \"tgt_padding\": tgt_padding\n",
    "                    }\n",
    "                    self._add_example(example)\n",
    "                    count += 1\n",
    "\n",
    "                print(f\" | source padding_len = {self.source_max_sen_len}.\")\n",
    "                print(f\" | target padding_len = {self.target_max_sen_len}.\")\n",
    "                print(f\" | Total  activate  sen = {count}.\")\n",
    "                print(f\" | Total  sen = {count}.\")\n",
    "\n",
    "                if self.schema_address is not None:\n",
    "                    provlist = [count, self.source_max_sen_len, self.source_max_sen_len,\n",
    "                                self.target_max_sen_len, self.target_max_sen_len, self.target_max_sen_len]\n",
    "                    columns = [\"src\", \"src_padding\", \"prev_opt\", \"target\", \"tgt_padding\"]\n",
    "                    with open(self.schema_address, \"w\", encoding=\"utf-8\") as  f:\n",
    "                        f.write(\"{\\n\")\n",
    "                        f.write('  \"datasetType\":\"MS\",\\n')\n",
    "                        f.write('  \"numRows\":%s,\\n' % provlist[0])\n",
    "                        f.write('  \"columns\":{\\n')\n",
    "                        t = 1\n",
    "                        for name in columns:\n",
    "                            f.write('    \"%s\":{\\n' % name)\n",
    "                            f.write('      \"type\":\"int64\",\\n')\n",
    "                            f.write('      \"rank\":1,\\n')\n",
    "                            f.write('      \"shape\":[%s]\\n' % provlist[t])\n",
    "                            f.write('    }')\n",
    "                            if t < len(columns):\n",
    "                                f.write(',')\n",
    "                            f.write('\\n')\n",
    "                            t += 1\n",
    "                        f.write('  }\\n}\\n')\n",
    "                        print(\" | Write to \" + self.schema_address)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2.2.5 数据集转换\n",
    "定义`create_dataset()`方法统一将测试单语种数据集和训练双语种数据集转换成 `MindRecord`数据集格式。\n",
    "\n",
    "首先实例化上述定义的`Tokenizer`分词器，利用`tokenizer`和`vocab`对测试数据集和训练数据集进行词向量转换，最后写入`MindRecord`数据集文件。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " | It's writing, please wait a moment.\n",
      " | count the max_sen_len of corpus ./data/wmt16_de_en/newstest2014.en.\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "'Tokenizer' object has no attribute 'bpe'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_32035/1846146685.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     48\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\" | Vocabulary size: {tokenizer.vocab_size}.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0mcreate_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/tmp/ipykernel_32035/1846146685.py\u001b[0m in \u001b[0;36mcreate_dataset\u001b[0;34m()\u001b[0m\n\u001b[1;32m     26\u001b[0m             \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbasename\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_src_file\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\".mindrecord\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m         ),\n\u001b[0;32m---> 28\u001b[0;31m         \u001b[0mtrain_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     29\u001b[0m     )\n\u001b[1;32m     30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_32035/2689467905.py\u001b[0m in \u001b[0;36mwrite_to_mindrecord\u001b[0;34m(self, path, train_mode, shard_num, desc)\u001b[0m\n\u001b[1;32m     54\u001b[0m             \u001b[0mwriter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_schema\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_TEST_SCHEMA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdesc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     55\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_examples\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     57\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     58\u001b[0m         \u001b[0mwriter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite_raw_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_examples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_32035/3790664692.py\u001b[0m in \u001b[0;36m_load\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     22\u001b[0m                 \u001b[0mmax_src\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m                 \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_pair\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_src_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m                     \u001b[0msrc_tokens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokenize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_pair\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     25\u001b[0m                     \u001b[0msrc_len\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msrc_tokens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     26\u001b[0m                     \u001b[0;32mif\u001b[0m \u001b[0msrc_len\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mmax_src\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_32035/2009192252.py\u001b[0m in \u001b[0;36mtokenize\u001b[0;34m(self, sentence)\u001b[0m\n\u001b[1;32m     70\u001b[0m         \u001b[0;34m\"\"\"对句子分词\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     71\u001b[0m         \u001b[0mtokenized\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msacremoses_tokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokenize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msentence\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_str\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 72\u001b[0;31m         \u001b[0mbpe\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_line\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtokenized\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     73\u001b[0m         \u001b[0msentence\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     74\u001b[0m         \u001b[0minputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtok2idx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msentence\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'Tokenizer' object has no attribute 'bpe'"
     ]
    }
   ],
   "source": [
    "def create_dataset():\n",
    "    dicts = []\n",
    "    train_src_file = \"train.tok.clean.bpe.32000.en\"\n",
    "    train_tgt_file = \"train.tok.clean.bpe.32000.de\"\n",
    "    test_src_file = \"newstest2014.en\"\n",
    "    test_tgt_file = \"newstest2014.de\"\n",
    "\n",
    "    output_folder = './data/dataset_menu'\n",
    "    src_folder = './data/wmt16_de_en'\n",
    "\n",
    "    vocab = \"./data/wmt16_de_en/vocab.bpe.32000\"\n",
    "    bpe_codes = \"./data/wmt16_de_en/bpe.32000\"\n",
    "    pad_vocab = 8\n",
    "    tokenizer = Tokenizer(vocab, bpe_codes, src_en='en', tgt_de='de', vocab_pad=pad_vocab)\n",
    "\n",
    "    test = TextDataLoader(\n",
    "        src_filepath=os.path.join(src_folder, test_src_file),\n",
    "        tokenizer=tokenizer,\n",
    "        source_max_sen_len=None,\n",
    "        schema_address=output_folder + \"/\" + test_src_file + \".json\"\n",
    "    )\n",
    "    print(f\" | It's writing, please wait a moment.\")\n",
    "    test.write_to_mindrecord(\n",
    "        path=os.path.join(\n",
    "            output_folder,\n",
    "            os.path.basename(test_src_file) + \".mindrecord\"\n",
    "        ),\n",
    "        train_mode=False\n",
    "    )\n",
    "\n",
    "    train = BiLingualDataLoader(\n",
    "        src_filepath=os.path.join(src_folder, train_src_file),\n",
    "        tgt_filepath=os.path.join(src_folder, train_tgt_file),\n",
    "        tokenizer=tokenizer,\n",
    "        source_max_sen_len=51,\n",
    "        target_max_sen_len=50,\n",
    "        schema_address=output_folder + \"/\" + train_src_file + \".json\"\n",
    "    )\n",
    "    print(f\" | It's writing, please wait a moment.\")\n",
    "    train.write_to_mindrecord(\n",
    "        path=os.path.join(\n",
    "            output_folder,\n",
    "            os.path.basename(train_src_file) + \".mindrecord\"\n",
    "        ),\n",
    "        train_mode=True\n",
    "    )\n",
    "\n",
    "    print(f\" | Vocabulary size: {tokenizer.vocab_size}.\")\n",
    "create_dataset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 数据加载\n",
    "在将处理后的文本数据转换成词向量并按需要的格式写入`MindRecord`文件后，我们还需要对数据文件进行加载。\n",
    "\n",
    "首先定义数据加载方法，使用`mindspore.dataset`中的`MindDataset`读取之前转换的`MindRecord`文件，根据列名读取文件后转换为划分`batch`的`dataset`数据集类实例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.common.dtype as mstype\n",
    "import mindspore.dataset as ds\n",
    "import mindspore.dataset.transforms as deC\n",
    "\n",
    "def _load_dataset(input_files, batch_size, sink_mode=False,\n",
    "                  rank_size=1, rank_id=0, shuffle=True, drop_remainder=True,\n",
    "                  is_translate=False):\n",
    "    \"\"\"\n",
    "    根据传入的参数加载数据集。\n",
    "\n",
    "    Args:\n",
    "        input_files (list): Data files.\n",
    "        batch_size (int): Batch size.\n",
    "        sink_mode (bool): Whether enable sink mode.\n",
    "        rank_size (int): Rank size.\n",
    "        rank_id (int): Rank id.\n",
    "        shuffle (bool): Whether shuffle dataset.\n",
    "        drop_remainder (bool): Whether drop the last possibly incomplete batch.\n",
    "        is_translate (bool): Whether translate the text.\n",
    "\n",
    "    Returns:\n",
    "        Dataset, dataset instance.\n",
    "    \"\"\"\n",
    "    if not input_files:\n",
    "        raise FileNotFoundError(\"Require at least one dataset.\")\n",
    "\n",
    "    if not isinstance(sink_mode, bool):\n",
    "        raise ValueError(\"`sink` must be type of bool.\")\n",
    "\n",
    "    for datafile in input_files:\n",
    "        print(f\" | Loading {datafile}.\")\n",
    "\n",
    "    if not is_translate:\n",
    "        data_set = ds.MindDataset(\n",
    "            input_files, columns_list=[\n",
    "                \"src\", \"src_padding\",\n",
    "                \"prev_opt\",\n",
    "                \"target\", \"tgt_padding\"\n",
    "            ], shuffle=False, num_shards=rank_size, shard_id=rank_id,\n",
    "            num_parallel_workers=8\n",
    "        )\n",
    "\n",
    "        ori_dataset_size = data_set.get_dataset_size()\n",
    "        print(f\" | Dataset size: {ori_dataset_size}.\")\n",
    "        if shuffle:\n",
    "            data_set = data_set.shuffle(buffer_size=ori_dataset_size // 20)\n",
    "        type_cast_op = deC.TypeCast(mstype.int32)\n",
    "        data_set = data_set.map(input_columns=\"src\", operations=type_cast_op, num_parallel_workers=8)\n",
    "        data_set = data_set.map(input_columns=\"src_padding\", operations=type_cast_op, num_parallel_workers=8)\n",
    "        data_set = data_set.map(input_columns=\"prev_opt\", operations=type_cast_op, num_parallel_workers=8)\n",
    "        data_set = data_set.map(input_columns=\"target\", operations=type_cast_op, num_parallel_workers=8)\n",
    "        data_set = data_set.map(input_columns=\"tgt_padding\", operations=type_cast_op, num_parallel_workers=8)\n",
    "\n",
    "        data_set = data_set.rename(\n",
    "            input_columns=[\"src\",\n",
    "                           \"src_padding\",\n",
    "                           \"prev_opt\",\n",
    "                           \"target\",\n",
    "                           \"tgt_padding\"],\n",
    "            output_columns=[\"source_eos_ids\",\n",
    "                            \"source_eos_mask\",\n",
    "                            \"target_sos_ids\",\n",
    "                            \"target_eos_ids\",\n",
    "                            \"target_eos_mask\"]\n",
    "        )\n",
    "        data_set = data_set.batch(batch_size, drop_remainder=drop_remainder)\n",
    "    else:\n",
    "        data_set = ds.MindDataset(\n",
    "            input_files, columns_list=[\n",
    "                \"src\", \"src_padding\"\n",
    "            ],\n",
    "            shuffle=False, num_shards=rank_size, shard_id=rank_id,\n",
    "            num_parallel_workers=8\n",
    "        )\n",
    "\n",
    "        ori_dataset_size = data_set.get_dataset_size()\n",
    "        print(f\" | Dataset size: {ori_dataset_size}.\")\n",
    "        if shuffle:\n",
    "            data_set = data_set.shuffle(buffer_size=ori_dataset_size // 20)\n",
    "        type_cast_op = deC.TypeCast(mstype.int32)\n",
    "        data_set = data_set.map(input_columns=\"src\", operations=type_cast_op, num_parallel_workers=8)\n",
    "        data_set = data_set.map(input_columns=\"src_padding\", operations=type_cast_op, num_parallel_workers=8)\n",
    "\n",
    "        data_set = data_set.rename(\n",
    "            input_columns=[\"src\",\n",
    "                           \"src_padding\"],\n",
    "            output_columns=[\"source_eos_ids\",\n",
    "                            \"source_eos_mask\"]\n",
    "        )\n",
    "        data_set = data_set.batch(batch_size, drop_remainder=drop_remainder)\n",
    "\n",
    "    return data_set\n",
    "\n",
    "\n",
    "def load_dataset(data_files: list, batch_size: int, sink_mode: bool,\n",
    "                 rank_size: int = 1, rank_id: int = 0, shuffle=True, drop_remainder=True, is_translate=False):\n",
    "    \"\"\"\n",
    "    定义加载数据集的方法.\n",
    "\n",
    "    Args:\n",
    "        data_files (list): Data files.\n",
    "        batch_size (int): Batch size.\n",
    "        sink_mode (bool): Whether enable sink mode.\n",
    "        rank_size (int): Rank size.\n",
    "        rank_id (int): Rank id.\n",
    "        shuffle (bool): Whether shuffle dataset.\n",
    "\n",
    "    Returns:\n",
    "        Dataset, dataset instance.\n",
    "    \"\"\"\n",
    "    return _load_dataset(data_files, batch_size, sink_mode, rank_size, rank_id, shuffle=shuffle,\n",
    "                         drop_remainder=drop_remainder, is_translate=is_translate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.4 config文件加载配置\n",
    "由于在模型训练、评估过程中参数较多，因此我们将所有参数直接封装成单独的参数类，方便后续直接调用。\n",
    "\n",
    "定义`get_config()`函数统一配置如训练平台、配置参数和优化器参数等参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def get_config(config):\n",
    "    '''\n",
    "    获取config参数\n",
    "    '''\n",
    "    return config\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.模型构建\n",
    "### 3.1 Embedding层\n",
    "定义`EmbeddingLookup`类，产生词嵌入表。\n",
    "\n",
    "在这里有一个关键参数是`use_one_hot_embeddings`。当设置`use_one_hot_embeddings`参数为True时，则会在生成embedding时先生成一个对应的onehot张量，并用onehot 张量与embedding table相乘最终获得对应的embedding张量。而当`use_one_hot_embeddings`参数设置为False时，则会直接利用mindspore.ops.Gather()方法在embedding table中将对应的embedding提取出来，组合成为对应的embedding张量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import mindspore.common.dtype as mstype\n",
    "from mindspore import nn\n",
    "from mindspore.ops import operations as P\n",
    "from mindspore.common.tensor import Tensor\n",
    "from mindspore.common.parameter import Parameter\n",
    "\n",
    "class EmbeddingLookup(nn.Cell):\n",
    "    \"\"\"\n",
    "    构建一个具有固定词典和size的词嵌入查找表\n",
    "\n",
    "    Args:\n",
    "        is_training (bool): Whether to train.\n",
    "        vocab_size (int): Size of the dictionary of embeddings.\n",
    "        embed_dim (int): The size of word embedding.\n",
    "        initializer_range (int): The initialize range of parameters.\n",
    "        use_one_hot_embeddings (bool): Whether use one-hot embedding. Default: False.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 is_training,\n",
    "                 vocab_size,\n",
    "                 embed_dim,\n",
    "                 initializer_range=0.1,\n",
    "                 use_one_hot_embeddings=False):\n",
    "\n",
    "        super(EmbeddingLookup, self).__init__()\n",
    "        self.is_training = is_training\n",
    "        self.embedding_dim = embed_dim\n",
    "        self.vocab_size = vocab_size\n",
    "        self.use_one_hot_embeddings = use_one_hot_embeddings\n",
    "\n",
    "        init_weight = np.random.normal(-initializer_range, initializer_range, size=[vocab_size, embed_dim])\n",
    "        self.embedding_table = Parameter(Tensor(init_weight, mstype.float32))\n",
    "        self.expand = P.ExpandDims()\n",
    "        self.gather = P.Gather()\n",
    "        self.one_hot = P.OneHot()\n",
    "        self.on_value = Tensor(1.0, mstype.float32)\n",
    "        self.off_value = Tensor(0.0, mstype.float32)\n",
    "        self.array_mul = P.MatMul()\n",
    "        self.reshape = P.Reshape()\n",
    "        self.get_shape = P.Shape()\n",
    "        self.cast = P.Cast()\n",
    "\n",
    "    def construct(self, input_ids):\n",
    "        \"\"\"\n",
    "        构建网络\n",
    "\n",
    "        Args:\n",
    "            input_ids (Tensor): A batch of sentences with shape (N, T).\n",
    "\n",
    "        Returns:\n",
    "            Tensor, word embeddings with shape (N, T, D)\n",
    "        \"\"\"\n",
    "        _shape = self.get_shape(input_ids)  # (N, T).\n",
    "        _batch_size = _shape[0]\n",
    "        _max_len = _shape[1]\n",
    "        if self.is_training:\n",
    "            embedding_table = self.cast(self.embedding_table, mstype.float16)\n",
    "        else:\n",
    "            embedding_table = self.embedding_table\n",
    "\n",
    "        flat_ids = self.reshape(input_ids, (_batch_size * _max_len,))\n",
    "        if self.use_one_hot_embeddings:\n",
    "            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)\n",
    "            if self.is_training:\n",
    "                one_hot_ids = self.cast(one_hot_ids, mstype.float16)\n",
    "            output_for_reshape = self.array_mul(one_hot_ids, embedding_table)\n",
    "        else:\n",
    "            output_for_reshape = self.gather(embedding_table, flat_ids, 0)\n",
    "\n",
    "        output = self.reshape(output_for_reshape, (_batch_size, _max_len, self.embedding_dim))\n",
    "        if self.is_training:\n",
    "            output = self.cast(output, mstype.float32)\n",
    "            embedding_table = self.cast(embedding_table, mstype.float32)\n",
    "        return output, embedding_table\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Encoder\n",
    "#### 3.2.1 DynamicRNNCell\n",
    "定义`DynamicRNNCell`类，该类中首先定义了$w$权重参数和$b$偏置参数。\n",
    "\n",
    "随后根据当前运行平台的不同，初始化不同的循环神经网络。在华为的`ascend`平台，使用的是`mindspore.ops`中的`DynamicRNN`；在其他平台上，使用的是传统的`mindspore.nn`中的`LSTM`。\n",
    "\n",
    "同时需要对不同平台的数据格式进行转换，所以也使用了`ops`中的`Cast`类，能够将数据根据不同平台转换为不同计算类型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import context\n",
    "\n",
    "class DynamicRNNCell(nn.Cell):\n",
    "    \"\"\"\n",
    "    动态RNN网络 Cell.\n",
    "\n",
    "    Args:\n",
    "        num_setp (int): Lengths of setences.\n",
    "        batch_size (int): Batch size.\n",
    "        word_embed_dim (int): Input size.\n",
    "        hidden_size (int): Hidden size .\n",
    "        initializer_range (float): Initial range. Default: 0.02\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 num_setp=50,\n",
    "                 batch_size=128,\n",
    "                 word_embed_dim=1024,\n",
    "                 hidden_size=1024,\n",
    "                 initializer_range=0.1):\n",
    "        super(DynamicRNNCell, self).__init__()\n",
    "        self.num_step = num_setp\n",
    "        self.batch_size = batch_size\n",
    "        self.input_size = word_embed_dim\n",
    "        self.hidden_size = hidden_size\n",
    "        # w\n",
    "        dynamicRNN_w = np.random.uniform(-initializer_range, initializer_range,\n",
    "                                         size=[self.input_size + self.hidden_size, 4 * self.hidden_size])\n",
    "        self.dynamicRNN_w = Parameter(Tensor(dynamicRNN_w, mstype.float32))\n",
    "        # b\n",
    "        dynamicRNN_b = np.random.uniform(-initializer_range, initializer_range, size=[4 * self.hidden_size])\n",
    "        self.dynamicRNN_b = Parameter(Tensor(dynamicRNN_b, mstype.float32))\n",
    "\n",
    "        self.dynamicRNN_h = Tensor(np.zeros((1, self.batch_size, self.hidden_size)), mstype.float32)\n",
    "        self.dynamicRNN_c = Tensor(np.zeros((1, self.batch_size, self.hidden_size)), mstype.float32)\n",
    "        self.cast = P.Cast()\n",
    "        self.is_ascend = context.get_context(\"device_target\") == \"Ascend\"\n",
    "        if self.is_ascend:\n",
    "            self.compute_type = mstype.float16\n",
    "            self.rnn = P.DynamicRNN()\n",
    "        else:\n",
    "            self.compute_type = mstype.float32\n",
    "            self.lstm = nn.LSTM(self.input_size,\n",
    "                                self.hidden_size,\n",
    "                                num_layers=1,\n",
    "                                has_bias=True,\n",
    "                                batch_first=False,\n",
    "                                dropout=0.0,\n",
    "                                bidirectional=False)\n",
    "\n",
    "    def construct(self, x, init_h=None, init_c=None):\n",
    "        \"\"\"动态RNN网络，区分GPU和Ascend平台\"\"\"\n",
    "        if init_h is None or init_c is None:\n",
    "            init_h = self.cast(self.dynamicRNN_h, self.compute_type)\n",
    "            init_c = self.cast(self.dynamicRNN_c, self.compute_type)\n",
    "        if self.is_ascend:\n",
    "            w = self.cast(self.dynamicRNN_w, self.compute_type)\n",
    "            b = self.cast(self.dynamicRNN_b, self.compute_type)\n",
    "            output, hn, cn, _, _, _, _, _ = self.rnn(x, w, b, None, init_h, init_c)\n",
    "        else:\n",
    "            output, (hn, cn) = self.lstm(x, (init_h, init_c))\n",
    "        return output, hn, cn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.2.2 DynamicRNNNet\n",
    "定义`DynamicRNNNet`类，该类封装了上述定义的`DynamicRNNCell`类，实现完整RNN功能。\n",
    "\n",
    "本质上华为Ascend昇腾AI处理器属于NPU(Neural-Network Processing Units)，与GPU运算精度和运算方式有所差异，因此这里同样的将两种平台进行区分。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DynamicRNNNet(nn.Cell):\n",
    "    \"\"\"\n",
    "    DynamicRNN Network.\n",
    "\n",
    "    Args:\n",
    "        seq_length (int): Lengths of setences.\n",
    "        batchsize (int): Batch size.\n",
    "        word_embed_dim (int): Input size.\n",
    "        hidden_size (int): Hidden size.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 seq_length=80,\n",
    "                 batchsize=128,\n",
    "                 word_embed_dim=1024,\n",
    "                 hidden_size=1024):\n",
    "        super(DynamicRNNNet, self).__init__()\n",
    "        self.max_length = seq_length\n",
    "        self.hidden_size = hidden_size\n",
    "        self.cast = P.Cast()\n",
    "        self.concat = P.Concat(axis=0)\n",
    "        self.get_shape = P.Shape()\n",
    "        self.net = DynamicRNNCell(num_setp=seq_length,\n",
    "                                  batch_size=batchsize,\n",
    "                                  word_embed_dim=word_embed_dim,\n",
    "                                  hidden_size=hidden_size)\n",
    "        self.is_ascend = context.get_context(\"device_target\") == \"Ascend\"\n",
    "        if self.is_ascend:\n",
    "            self.compute_type = mstype.float16\n",
    "        else:\n",
    "            self.compute_type = mstype.float32\n",
    "\n",
    "    def construct(self, inputs, init_state=None):\n",
    "        \"\"\"动态RNN网络。\"\"\"\n",
    "        inputs = self.cast(inputs, self.compute_type)\n",
    "        if init_state is not None:\n",
    "            init_h = self.cast(init_state[0:1, :, :], self.compute_type)\n",
    "            init_c = self.cast(init_state[-1:, :, :], self.compute_type)\n",
    "            out, state_h, state_c = self.net(inputs, init_h, init_c)\n",
    "        else:\n",
    "            out, state_h, state_c = self.net(inputs)\n",
    "        out = self.cast(out, mstype.float32)\n",
    "        state = self.concat((state_h[-1:, :, :], state_c[-1:, :, :]))\n",
    "        state = self.cast(state, mstype.float32)\n",
    "        # out:[T,b,D], state:[2,b,D]\n",
    "        return out, state\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.2.3 GNMTEncoder\n",
    "\n",
    "多层堆叠的LSTM网络通常会比层数少的网络有更好的性能，然而，简单的错层堆叠会造成训练的缓慢以及容易受到剃度爆炸或梯度消失的影响，在实验中，简单堆叠在4层工作良好，6层简单堆叠性能还好的网络很少见，8层的就更罕见了，为了解决这个问题，在模型中引入了残差连接。\n",
    "\n",
    "通过上述定义的`DynamicRNNNet`实现在`Ascend`和`GPU`上都可以运行的`GNMTEncoder`。\n",
    "\n",
    "本文实现的GNMT v2的Encoder相对于GNMT有所改进，本文构造的GNMT v2的Encoder结构如下：\n",
    "- 总共4层LSTM，每层的隐藏向量大小1024，第一层为双向LSTM，其余为单向LSTM。\n",
    "- 从第三层开始加上残差连接。\n",
    "- 所有LSTM层的输入都应用dropout，dropout概率设置为0.2。\n",
    "- LSTM层的隐藏状态初始化为零。\n",
    "- LSTM层的`weight`和`bias`用均匀分布`(-0.1，0.1)`初始化。\n",
    "\n",
    "模型构建如下：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import nn\n",
    "from mindspore.ops import operations as P\n",
    "from mindspore.common import dtype as mstype\n",
    "\n",
    "\n",
    "class GNMTEncoder(nn.Cell):\n",
    "    \"\"\"\n",
    "    GNMT encoder的实现部分.\n",
    "\n",
    "    Args:\n",
    "        config: Configuration of GNMT network.\n",
    "        is_training (bool): Whether to train.\n",
    "        compute_type (mstype): Mindspore data type.\n",
    "\n",
    "    Returns:\n",
    "        Tensor, shape of (N, T, D).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 config,\n",
    "                 is_training: bool,\n",
    "                 compute_type=mstype.float32):\n",
    "        super(GNMTEncoder, self).__init__()\n",
    "        self.input_mask_from_dataset = config.input_mask_from_dataset\n",
    "        self.max_positions = config.seq_length\n",
    "        self.attn_embed_dim = config.hidden_size\n",
    "\n",
    "        self.num_layers = config.num_hidden_layers\n",
    "        self.hidden_dropout_prob = config.hidden_dropout_prob\n",
    "        self.vocab_size = config.vocab_size\n",
    "        self.seq_length = config.seq_length\n",
    "        self.batch_size = config.batch_size\n",
    "        self.word_embed_dim = config.hidden_size\n",
    "\n",
    "        self.transpose = P.Transpose()\n",
    "        self.transpose_orders = (1, 0, 2)\n",
    "        self.reshape = P.Reshape()\n",
    "        self.concat = P.Concat(axis=-1)\n",
    "        encoder_layers = []\n",
    "        for i in range(0, self.num_layers + 1):\n",
    "            if i == 2:\n",
    "                # the bidirectional layer's output is [T,D,2N]\n",
    "                scaler = 2\n",
    "            else:\n",
    "                # the rest layer's output is [T,D,N]\n",
    "                scaler = 1\n",
    "            layer = DynamicRNNNet(seq_length=self.seq_length,\n",
    "                                  batchsize=self.batch_size,\n",
    "                                  word_embed_dim=scaler * self.word_embed_dim,\n",
    "                                  hidden_size=self.word_embed_dim)\n",
    "            encoder_layers.append(layer)\n",
    "        self.encoder_layers = nn.CellList(encoder_layers)\n",
    "        self.reverse_v2 = P.ReverseV2(axis=[0])\n",
    "        self.dropout = nn.Dropout(keep_prob=1.0 - config.hidden_dropout_prob)\n",
    "\n",
    "    def construct(self, inputs):\n",
    "        \"\"\"Encoder.\"\"\"\n",
    "        inputs = self.dropout(inputs)\n",
    "        # bidirectional layer, fwd_encoder_outputs: [T,N,D]\n",
    "        fwd_encoder_outputs, _ = self.encoder_layers[0](inputs)\n",
    "\n",
    "        # the input need reverse.\n",
    "        inputs_r = self.reverse_v2(inputs)\n",
    "        bak_encoder_outputs, _ = self.encoder_layers[1](inputs_r)\n",
    "        # the result need reverse.\n",
    "        bak_encoder_outputs = self.reverse_v2(bak_encoder_outputs)\n",
    "\n",
    "        # bi_encoder_outputs: [T,N,2D]\n",
    "        bi_encoder_outputs = self.concat((fwd_encoder_outputs, bak_encoder_outputs))\n",
    "\n",
    "        # 1st unidirectional layer. encoder_outputs: [T,N,D]\n",
    "        bi_encoder_outputs = self.dropout(bi_encoder_outputs)\n",
    "        encoder_outputs, _ = self.encoder_layers[2](bi_encoder_outputs)\n",
    "        # Build all the rest unidi layers of encoder\n",
    "        for i in range(3, self.num_layers + 1):\n",
    "            residual = encoder_outputs\n",
    "            encoder_outputs = self.dropout(encoder_outputs)\n",
    "            # [T,N,D] -> [T,N,D]\n",
    "            encoder_outputs_o, _ = self.encoder_layers[i](encoder_outputs)\n",
    "            encoder_outputs = encoder_outputs_o + residual\n",
    "\n",
    "        return encoder_outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 Decoder\n",
    "#### 3.3.1 BahdanauAttention\n",
    "Luong Attention 和 Bahdanau Attention 是最经典的两种注意力机制，我们这里实现Bahdanau Attention机制。\n",
    "\n",
    "\n",
    "Bahdanau本质是一种 加性attention机制，将decoder的隐状态和encoder所有位置输出通过线性组合对齐，得到context向量，用于改善序列到序列的翻译模型。\n",
    "\n",
    "其本质是两层全连接网络，隐藏层激活函数tanh，输出层维度为1。\n",
    "\n",
    "\n",
    "Bahdanau的特点为：\n",
    "\n",
    "- 编码器隐状态 ：编码器对于每一个输入向量产生一个隐状态向量；\n",
    "- 计算对齐分数：使用上一时刻的隐状态$\\boldsymbol s_{t-1}$和编码器每个位置输出$\\boldsymbol x_i$计算对齐分数（使用前馈神经网络计算），编码器最终时刻隐状态可作为解码器初始时刻隐状态；\n",
    "- 概率化对齐分数：解码器上一时刻隐状态$\\boldsymbol s_{t-1}$在编码器每个位置输出的对齐分数，通过softmax转化为概率分布向量；\n",
    "- 计算上下文向量：根据概率分布化的对齐分数，加权编码器各位置输出，得上下文向量$\\boldsymbol c_t$；\n",
    "- 解码器输出：将上下文向量$\\boldsymbol c_t$和上一时刻编码器输出$\\hat y_{t-1} $对应的embedding拼接，作为当前时刻编码器输入，经RNN网络产生新的输出和隐状态，训练过程中有真实目标序列$\\boldsymbol y=(y_1\\cdots y_m)$，多使用$y_{t-1}$取代$\\hat y_{t-1}$作为解码器$t$时刻输入；\n",
    "\n",
    "时刻$t$，解码器的隐状态表示为\n",
    "$$\n",
    "\\boldsymbol s_t = f(\\boldsymbol s_{t-1},\\boldsymbol c_t,y_{t-1})\n",
    "$$\n",
    "\n",
    "时刻$t$的隐状态$\\boldsymbol s_{t-1}$, 对编码器各时刻输出X XX的注意力分数为：\n",
    "$$\n",
    "\\boldsymbol\\alpha_t(\\boldsymbol s_{t-1}, X) = \\text{softmax}(\\tanh( \\boldsymbol s_{t-1}W_{decoder}+X W_{encoder})W_{alignment}),\\quad \\boldsymbol c_t=\\sum_i\\alpha_{ti}\\boldsymbol x_i\n",
    "$$\n",
    "如图1所示，解释了使用Bahdanau注意力机制的解码过程：\n",
    "\n",
    "<center>\n",
    "    <img style=\"border-radius: 0.3125em;\n",
    "    box-shadow: 0 2px 4px 0 rgba(34,36,38,.12),0 2px 10px 0 rgba(34,36,38,.08); width:50%; height:50%;\" \n",
    "    src=\"https://img-blog.csdnimg.cn/20200613134542278.png\">\n",
    "    <br>\n",
    "    <div style=\"color:orange; border-bottom: 1px solid #d9d9d9;\n",
    "    display: inline-block;\n",
    "    color: #999;\n",
    "    padding: 2px;\">图1. Bahdanau注意力机制的解码过程</div>\n",
    "</center>\n",
    "\n",
    "\n",
    "根据以上原理，实现的BahdanauAttention如下所示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.common.dtype as mstype\n",
    "import mindspore.ops.operations as P\n",
    "from mindspore import nn\n",
    "from mindspore.common.tensor import Tensor\n",
    "from mindspore.common.parameter import Parameter\n",
    "from mindspore.common.initializer import Uniform\n",
    "\n",
    "INF = 65504.0\n",
    "\n",
    "\n",
    "class BahdanauAttention(nn.Cell):\n",
    "    \"\"\"\n",
    "    BahdanauAttention的原理实现.\n",
    "\n",
    "    Args:\n",
    "        is_training (bool): Whether to train.\n",
    "        query_size (int): feature dimension for query.\n",
    "        key_size (int): feature dimension for keys.\n",
    "        num_units (int): internal feature dimension.\n",
    "        normalize (bool): Whether to normalize.\n",
    "        initializer_range: range for uniform initializer parameters.\n",
    "\n",
    "    Returns:\n",
    "        Tensor, shape (t_q_length, N, D).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 is_training,\n",
    "                 query_size,\n",
    "                 key_size,\n",
    "                 num_units,\n",
    "                 normalize=False,\n",
    "                 initializer_range=0.1,\n",
    "                 compute_type=mstype.float16):\n",
    "        super(BahdanauAttention, self).__init__()\n",
    "        self.is_training = is_training\n",
    "        self.mask = None\n",
    "        self.query_size = query_size\n",
    "        self.key_size = key_size\n",
    "        self.normalize = normalize\n",
    "        self.num_units = num_units\n",
    "        self.linear_att = Parameter(Tensor(np.random.uniform(-initializer_range, initializer_range, size=[num_units]),\n",
    "                                           dtype=mstype.float32))\n",
    "        if self.normalize:\n",
    "            self.normalize_scalar = Parameter(Tensor(np.array([1.0 / num_units]), dtype=mstype.float32))\n",
    "            self.normalize_bias = Parameter(Tensor(np.zeros(num_units), dtype=mstype.float32))\n",
    "        self.transpose = P.Transpose()\n",
    "        self.transpose_orders = (1, 0, 2)\n",
    "        self.shape_op = P.Shape()\n",
    "\n",
    "        self.linear_q = nn.Dense(query_size,\n",
    "                                 num_units,\n",
    "                                 has_bias=False,\n",
    "                                 weight_init=Uniform(initializer_range)).to_float(compute_type)\n",
    "\n",
    "        self.linear_k = nn.Dense(key_size,\n",
    "                                 num_units,\n",
    "                                 has_bias=False,\n",
    "                                 weight_init=Uniform(initializer_range)).to_float(compute_type)\n",
    "        self.expand = P.ExpandDims()\n",
    "        self.tile = P.Tile()\n",
    "\n",
    "        self.norm = nn.Norm(axis=-1)\n",
    "        self.mul = P.Mul()\n",
    "        self.matmul = P.MatMul()\n",
    "        self.batchMatmul = P.BatchMatMul()\n",
    "        self.tanh = nn.Tanh()\n",
    "\n",
    "        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)\n",
    "        self.softmax = nn.Softmax(axis=-1)\n",
    "        self.reshape = P.Reshape()\n",
    "        self.cast = P.Cast()\n",
    "\n",
    "    def construct(self, query, keys, attention_mask=None):\n",
    "        \"\"\"\n",
    "        构造attention模块.\n",
    "\n",
    "        Args:\n",
    "            query (Tensor): Shape (t_q_length, N, D).\n",
    "            keys (Tensor): Shape (t_k_length, N, D).\n",
    "            attention_mask: Shape(N, t_k_length).\n",
    "        Returns:\n",
    "            Tensor, shape (t_q_length, N, D).\n",
    "        \"\"\"\n",
    "\n",
    "        # (t_k_length, N, D) -> (N, t_k_length, D).\n",
    "        keys = self.transpose(keys, self.transpose_orders)\n",
    "        # (t_q_length, N, D) -> (N, t_q_length, D).\n",
    "        query_trans = self.transpose(query, self.transpose_orders)\n",
    "\n",
    "        query_shape = self.shape_op(query_trans)\n",
    "        batch_size = query_shape[0]\n",
    "        t_q_length = query_shape[1]\n",
    "        t_k_length = self.shape_op(keys)[1]\n",
    "\n",
    "        # (N, t_q_length, D)\n",
    "        query_trans = self.reshape(query_trans, (batch_size * t_q_length, self.query_size))\n",
    "        if self.is_training:\n",
    "            query_trans = self.cast(query_trans, mstype.float16)\n",
    "        processed_query = self.linear_q(query_trans)\n",
    "        if self.is_training:\n",
    "            processed_query = self.cast(processed_query, mstype.float32)\n",
    "        processed_query = self.reshape(processed_query, (batch_size, t_q_length, self.num_units))\n",
    "        # (N, t_k_length, D)\n",
    "        keys = self.reshape(keys, (batch_size * t_k_length, self.key_size))\n",
    "        if self.is_training:\n",
    "            keys = self.cast(keys, mstype.float16)\n",
    "        processed_key = self.linear_k(keys)\n",
    "        if self.is_training:\n",
    "            processed_key = self.cast(processed_key, mstype.float32)\n",
    "        processed_key = self.reshape(processed_key, (batch_size, t_k_length, self.num_units))\n",
    "\n",
    "        # scores: (N, t_q_length, t_k_length)\n",
    "        scores = self.obtain_score(processed_query, processed_key)\n",
    "        # attention_mask: (N, t_k_length)\n",
    "        mask = attention_mask\n",
    "        if mask is not None:\n",
    "            mask = 1.0 - mask\n",
    "            mask = self.tile(self.expand(mask, 1), (1, t_q_length, 1))\n",
    "            scores += mask * (-INF)\n",
    "        # [batch_size, t_q_length, t_k_length]\n",
    "        scores_softmax = self.softmax(scores)\n",
    "\n",
    "        keys = self.reshape(keys, (batch_size, t_k_length, self.key_size))\n",
    "        if self.is_training:\n",
    "            keys = self.cast(keys, mstype.float16)\n",
    "            scores_softmax_fp16 = self.cast(scores_softmax, mstype.float16)\n",
    "        else:\n",
    "            scores_softmax_fp16 = scores_softmax\n",
    "\n",
    "        # (b, t_q_length, D)\n",
    "        context_attention = self.batchMatmul(scores_softmax_fp16, keys)\n",
    "        # [t_q_length, b, D]\n",
    "        context_attention = self.transpose(context_attention, self.transpose_orders)\n",
    "        if self.is_training:\n",
    "            context_attention = self.cast(context_attention, mstype.float32)\n",
    "\n",
    "        return context_attention, scores_softmax\n",
    "\n",
    "    def obtain_score(self, attention_q, attention_k):\n",
    "        \"\"\"\n",
    "        计算Bahdanau得分\n",
    "\n",
    "        Args:\n",
    "            attention_q: (batch_size, t_q_length, D).\n",
    "            attention_k: (batch_size, t_k_length, D).\n",
    "\n",
    "        returns:\n",
    "            scores: (batch_size, t_q_length, t_k_length).\n",
    "        \"\"\"\n",
    "        batch_size, t_k_length, D = self.shape_op(attention_k)\n",
    "        t_q_length = self.shape_op(attention_q)[1]\n",
    "        # (batch_size, t_q_length, t_k_length, n)\n",
    "        attention_q = self.tile(self.expand(attention_q, 2), (1, 1, t_k_length, 1))\n",
    "        attention_k = self.tile(self.expand(attention_k, 1), (1, t_q_length, 1, 1))\n",
    "        # (batch_size, t_q_length, t_k_length, n)\n",
    "        sum_qk_add = attention_q + attention_k\n",
    "\n",
    "        if self.normalize:\n",
    "            # (batch_size, t_q_length, t_k_length, n)\n",
    "            sum_qk_add = sum_qk_add + self.normalize_bias\n",
    "            linear_att_norm = self.linear_att / self.norm(self.linear_att)\n",
    "            linear_att_norm = self.cast(linear_att_norm, mstype.float32)\n",
    "            linear_att_norm = self.mul(linear_att_norm, self.normalize_scalar)\n",
    "        else:\n",
    "            linear_att_norm = self.linear_att\n",
    "\n",
    "        linear_att_norm = self.expand(linear_att_norm, -1)\n",
    "        sum_qk_add = self.reshape(sum_qk_add, (-1, D))\n",
    "\n",
    "        tanh_sum_qk = self.tanh(sum_qk_add)\n",
    "        if self.is_training:\n",
    "            linear_att_norm = self.cast(linear_att_norm, mstype.float16)\n",
    "            tanh_sum_qk = self.cast(tanh_sum_qk, mstype.float16)\n",
    "\n",
    "        scores_out = self.matmul(tanh_sum_qk, linear_att_norm)\n",
    "\n",
    "        # (N, t_q_length, t_k_length)\n",
    "        scores_out = self.reshape(scores_out, (batch_size, t_q_length, t_k_length))\n",
    "        if self.is_training:\n",
    "            scores_out = self.cast(scores_out, mstype.float32)\n",
    "        return scores_out\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.3.2 RecurrentAttention\n",
    "定义`RecurrentAttention`类，该类将上述定义的动态RNN网络和`BahdanauAttention`结合起来，封装成为一个适用于循环神经网络的`Attention`类。\n",
    "\n",
    "主要原理就是将RNN网络的输出输入到上述`BahdanauAttention`中，为后续计算提供上下文注意力和 Bahdanau score。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RecurrentAttention(nn.Cell):\n",
    "    \"\"\"\n",
    "    构造RecurrentAttention.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of features in input tensor.\n",
    "        context_size: number of features in output from encoder.\n",
    "        hidden_size: internal hidden size.\n",
    "        num_layers: number of layers in LSTM.\n",
    "        dropout: probability of dropout (on input to LSTM layer).\n",
    "        initializer_range: range for the uniform initializer.\n",
    "\n",
    "    Returns:\n",
    "        Tensor, shape (N, T, D).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 rnn,\n",
    "                 is_training=True,\n",
    "                 input_size=1024,\n",
    "                 context_size=1024,\n",
    "                 hidden_size=1024,\n",
    "                 num_layers=1,\n",
    "                 dropout=0.2,\n",
    "                 initializer_range=0.1):\n",
    "        super(RecurrentAttention, self).__init__()\n",
    "        self.dropout = nn.Dropout(keep_prob=1.0 - dropout)\n",
    "        self.rnn = rnn\n",
    "        self.attn = BahdanauAttention(is_training=is_training,\n",
    "                                      query_size=hidden_size,\n",
    "                                      key_size=hidden_size,\n",
    "                                      num_units=hidden_size,\n",
    "                                      normalize=True,\n",
    "                                      initializer_range=initializer_range,\n",
    "                                      compute_type=mstype.float16)\n",
    "\n",
    "    def construct(self, decoder_embedding, context_key, attention_mask=None, rnn_init_state=None):\n",
    "        # decoder_embedding: [t_q,N,D]\n",
    "        # context: [t_k,N,D]\n",
    "        # attention_mask: [N,t_k]\n",
    "        # [t_q,N,D]\n",
    "        decoder_embedding = self.dropout(decoder_embedding)\n",
    "        rnn_outputs, rnn_state = self.rnn(decoder_embedding, rnn_init_state)\n",
    "        # rnn_outputs:[t_q,b,D], attn_outputs:[t_q,b,D], scores:[b, t_q, t_k], rnn_state:tuple([2,b,D]).\n",
    "        attn_outputs, scores = self.attn(query=rnn_outputs, keys=context_key, attention_mask=attention_mask)\n",
    "        return rnn_outputs, attn_outputs, rnn_state, scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.3.3 GNMTDecoder\n",
    "\n",
    "本文定义GNMT v2模型的Decoder部分基本与GNMT的Decoder部分一致，有如下结构定义：\n",
    "- 具有隐藏大小 1024 和全连接分类器的 4 层单向 LSTM\n",
    "- 残差连接从第 3 层开始\n",
    "- dropout 应用于所有 LSTM 层的输入，dropout 的概率设置为 0.2\n",
    "- LSTM 层的隐藏状态由编码器的最后一个隐藏状态初始化\n",
    "- LSTM 层的权重和偏差初始化为均匀 (-0.1, 0.1) 分布\n",
    "- 全连接分类器的权重和偏差以均匀 (-0.1, 0.1) 分布初始化\n",
    "- Decoder过程中采用了Beam search算法。\n",
    "\n",
    "上文中提到过两个重要的改进，coverage penalty 和 length normalization。以往的Beam search会有利于偏短的结果，谷歌认为这是不合理的，并对长度进行了标准化处理\n",
    "\n",
    "$s(Y,X)=log(P(Y|X))/lp(Y)+cp(X;Y)$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GNMTDecoder(nn.Cell):\n",
    "    \"\"\"\n",
    "    Transformer decoder部分的实现.\n",
    "\n",
    "    Args:\n",
    "        attn_embed_dim (int): Dimensions of attention layer.\n",
    "        decoder_layers (int): Decoder layers.\n",
    "        num_attn_heads (int): Attention heads number.\n",
    "        intermediate_size (int): Hidden size of FFN.\n",
    "        attn_dropout_prob (float): Dropout rate in attention. Default: 0.1.\n",
    "        initializer_range (float): Initial range. Default: 0.02.\n",
    "        dropout_prob (float): Dropout rate between layers. Default: 0.1.\n",
    "        hidden_act (str): Non-linear activation function in FFN. Default: \"relu\".\n",
    "        compute_type (mstype): Mindspore data type. Default: mstype.float32.\n",
    "\n",
    "    Returns:\n",
    "        Tensor, shape of (N, T', D).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 config,\n",
    "                 is_training: bool,\n",
    "                 use_one_hot_embeddings: bool = False,\n",
    "                 initializer_range=0.1,\n",
    "                 infer_beam_width=1,\n",
    "                 compute_type=mstype.float16):\n",
    "        super(GNMTDecoder, self).__init__()\n",
    "\n",
    "        self.is_training = is_training\n",
    "        self.attn_embed_dim = config.hidden_size\n",
    "        self.num_layers = config.num_hidden_layers\n",
    "        self.hidden_dropout_prob = config.hidden_dropout_prob\n",
    "        self.vocab_size = config.vocab_size\n",
    "        self.seq_length = config.max_decode_length\n",
    "        # batchsize* beam_width for beam_search.\n",
    "        self.batch_size = config.batch_size * infer_beam_width\n",
    "        self.word_embed_dim = config.hidden_size\n",
    "        self.transpose = P.Transpose()\n",
    "        self.transpose_orders = (1, 0, 2)\n",
    "        self.reshape = P.Reshape()\n",
    "        self.concat = P.Concat(axis=-1)\n",
    "        self.state_concat = P.Concat(axis=0)\n",
    "        self.all_decoder_state = Tensor(np.zeros([self.num_layers, 2, self.batch_size, config.hidden_size]),\n",
    "                                        mstype.float32)\n",
    "\n",
    "        decoder_layers = []\n",
    "        for i in range(0, self.num_layers):\n",
    "            if i == 0:\n",
    "                # the inputs is [T,D,N]\n",
    "                scaler = 1\n",
    "            else:\n",
    "                # the inputs is [T,D,2N]\n",
    "                scaler = 2\n",
    "            layer = DynamicRNNNet(seq_length=self.seq_length,\n",
    "                                  batchsize=self.batch_size,\n",
    "                                  word_embed_dim=scaler * self.word_embed_dim,\n",
    "                                  hidden_size=self.word_embed_dim)\n",
    "            decoder_layers.append(layer)\n",
    "        self.decoder_layers = nn.CellList(decoder_layers)\n",
    "\n",
    "        self.att_rnn = RecurrentAttention(rnn=self.decoder_layers[0],\n",
    "                                          is_training=is_training,\n",
    "                                          input_size=self.word_embed_dim,\n",
    "                                          context_size=self.attn_embed_dim,\n",
    "                                          hidden_size=self.attn_embed_dim,\n",
    "                                          num_layers=1,\n",
    "                                          dropout=config.attention_dropout_prob)\n",
    "\n",
    "        self.dropout = nn.Dropout(keep_prob=1.0 - config.hidden_dropout_prob)\n",
    "\n",
    "        self.classifier = nn.Dense(config.hidden_size,\n",
    "                                   config.vocab_size,\n",
    "                                   has_bias=True,\n",
    "                                   weight_init=Uniform(initializer_range),\n",
    "                                   bias_init=Uniform(initializer_range)).to_float(compute_type)\n",
    "        self.cast = P.Cast()\n",
    "        self.shape_op = P.Shape()\n",
    "        self.expand = P.ExpandDims()\n",
    "\n",
    "    def construct(self, tgt_embeddings, encoder_outputs, attention_mask=None,\n",
    "                  decoder_init_state=None):\n",
    "        \"\"\"Decoder.\"\"\"\n",
    "        # tgt_embeddings: [T',N,D], encoder_outputs: [T,N,D], attention_mask: [N,T].\n",
    "        query_shape = self.shape_op(tgt_embeddings)\n",
    "        if decoder_init_state is None:\n",
    "            hidden_state = self.all_decoder_state\n",
    "        else:\n",
    "            hidden_state = decoder_init_state\n",
    "        # x:[t_q,b,D], attn:[t_q,b,D], scores:[b, t_q, t_k], state_0:[2,b,D].\n",
    "        x, attn, state_0, scores = self.att_rnn(decoder_embedding=tgt_embeddings, context_key=encoder_outputs,\n",
    "                                                attention_mask=attention_mask, rnn_init_state=hidden_state[0, :, :, :])\n",
    "        x = self.concat((x, attn))\n",
    "        x = self.dropout(x)\n",
    "        decoder_outputs, state_1 = self.decoder_layers[1](x, hidden_state[1, :, :, :])\n",
    "\n",
    "        all_decoder_state = self.state_concat((self.expand(state_0, 0), self.expand(state_1, 0)))\n",
    "\n",
    "        for i in range(2, self.num_layers):\n",
    "            residual = decoder_outputs\n",
    "            decoder_outputs = self.concat((decoder_outputs, attn))\n",
    "\n",
    "            decoder_outputs = self.dropout(decoder_outputs)\n",
    "            # 1st unidirectional layer. encoder_outputs: [T,N,D]\n",
    "            decoder_outputs, decoder_state = self.decoder_layers[i](decoder_outputs, hidden_state[i, :, :, :])\n",
    "            decoder_outputs += residual\n",
    "            all_decoder_state = self.state_concat((all_decoder_state, self.expand(decoder_state, 0)))\n",
    "\n",
    "        # [m, batch_size * beam_width, D]\n",
    "        decoder_outputs = self.reshape(decoder_outputs, (-1, self.word_embed_dim))\n",
    "        if self.is_training:\n",
    "            decoder_outputs = self.cast(decoder_outputs, mstype.float16)\n",
    "        decoder_outputs = self.classifier(decoder_outputs)\n",
    "        if self.is_training:\n",
    "            decoder_outputs = self.cast(decoder_outputs, mstype.float32)\n",
    "        # [m, batch_size * beam_width, V]\n",
    "        decoder_outputs = self.reshape(decoder_outputs, (query_shape[0], query_shape[1], self.vocab_size))\n",
    "        # all_decoder_state:[4,2,b,D]\n",
    "        return decoder_outputs, all_decoder_state, scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.4 推理过程\n",
    "#### 3.4.1 CreateAttentionPaddingsFromInputPaddings\n",
    "首先封装`CreateAttentionPaddingsFromInputPaddings`类，该类负责产出一个[batch_size, seq_length, seq_length]的矩阵,主要为了在计算self-attention系数时,padding位置不参与attention系数更新,具体padding位置由input_mask负责记录。\n",
    "\n",
    "可以用下图来解释，假设下图中的向量是8个句子做完embedding后的向量，后面的0代表句子的长度已结束。此时第一个句子的第一个编码在后面做self-Attention所需要和该句子中的其他向量计算，那么该和哪些向量计算呢？\n",
    "\n",
    "此处就是用新增加的纬度来表示需要计算的词向量，如图2中，下部分是转换成3D后新增加的一个向量来表示和哪些词进行计算，1代表能计算，0代表不进行计算。\n",
    "\n",
    "<center>\n",
    "    <img style=\"border-radius: 0.3125em;\n",
    "    box-shadow: 0 2px 4px 0 rgba(34,36,38,.12),0 2px 10px 0 rgba(34,36,38,.08);\" \n",
    "    src=\"https://img-blog.csdnimg.cn/20200408205742305.png\">\n",
    "    <br>\n",
    "    <div style=\"color:orange; border-bottom: 1px solid #d9d9d9;\n",
    "    display: inline-block;\n",
    "    color: #999;\n",
    "    padding: 2px;\">图2. 向量计算图</div>\n",
    "</center>\n",
    "\n",
    "根据上述原理，定义Attention的Padding封装类如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CreateAttentionPaddingsFromInputPaddings(nn.Cell):\n",
    "    \"\"\"\n",
    "    根据输入掩码创建注意力掩码.\n",
    "\n",
    "    Args:\n",
    "        config: Config class.\n",
    "\n",
    "    Returns:\n",
    "        Tensor, shape of (N, T, T).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 config,\n",
    "                 is_training=True):\n",
    "        super(CreateAttentionPaddingsFromInputPaddings, self).__init__()\n",
    "\n",
    "        self.is_training = is_training\n",
    "        self.input_mask = None\n",
    "        self.cast = P.Cast()\n",
    "        self.shape = P.Shape()\n",
    "        self.reshape = P.Reshape()\n",
    "        self.batch_matmul = P.BatchMatMul()\n",
    "        self.multiply = P.Mul()\n",
    "        self.shape = P.Shape()\n",
    "        # mask future positions\n",
    "        ones = np.ones(shape=(config.batch_size, config.seq_length, config.seq_length))\n",
    "        self.lower_triangle_mask = Tensor(np.tril(ones), dtype=mstype.float32)\n",
    "\n",
    "    def construct(self, input_mask, mask_future=False):\n",
    "        \"\"\"\n",
    "        Construct network.\n",
    "\n",
    "        Args:\n",
    "            input_mask (Tensor): Tensor mask vectors with shape (N, T).\n",
    "            mask_future (bool): Whether mask future (for decoder training).\n",
    "\n",
    "        Returns:\n",
    "            Tensor, shape of (N, T, T).\n",
    "        \"\"\"\n",
    "        input_shape = self.shape(input_mask)\n",
    "        # Add this for infer as the seq_length will increase.\n",
    "        shape_right = (input_shape[0], 1, input_shape[1])\n",
    "        shape_left = input_shape + (1,)\n",
    "        if self.is_training:\n",
    "            input_mask = self.cast(input_mask, mstype.float16)\n",
    "        mask_left = self.reshape(input_mask, shape_left)\n",
    "        mask_right = self.reshape(input_mask, shape_right)\n",
    "\n",
    "        attention_mask = self.batch_matmul(mask_left, mask_right)\n",
    "        if self.is_training:\n",
    "            attention_mask = self.cast(attention_mask, mstype.float32)\n",
    "\n",
    "        if mask_future:\n",
    "            attention_mask = self.multiply(attention_mask, self.lower_triangle_mask)\n",
    "\n",
    "        return attention_mask\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.4.2 Beam Serach Decoder\n",
    "\n",
    "在机器翻译任务中，都需要进行词或者字符序列的生成。我们想要找到最匹配的文本序列，而要在全局搜索这个最有解空间，往往是不可能的（因为词典太大），建设生成序列长度为N，词典大小为V， 则复杂度为 V^N次方。这实际上是一个NP难题。退而求其次，我们使用启发式算法，来找到可能的最优解，或者说足够好的解。\n",
    "\n",
    "传统的获取解码器输出的过程中，每次只选择概率最大的那个结果，作为当前时间步的输出，等到输出结束，我们会发现，整个句子可能并不通顺。虽然在每一个时间步上的输出确实是概率最大的，但是整体的概率确不一定最大的，我们经常把它叫做greedy search(贪心算法)。\n",
    "\n",
    "为了解决上述的问题，可以考虑计算全部的输出的概率乘积，选择最大概率的路径，这样可以达到全局最优解。但是这样的话，意味着如果句子很长，候选词很多，那么需要保存的数据就会非常大，需要计算的数据量就很大。Beam Search 就是介于上述两种方法的一个这种的方法，假设Beam width=2，表示每次保存的最大的概率的个数，这里每次保存两个，在下一个时间步骤一样，也是保留两个，这样就可以达到约束搜索空间大小的目的，从而提高算法的效率。Beam Search也不是全局最优。【维特比算法是全局最优】\n",
    "\n",
    "使用一个树状图来表示每个time step的可能输出，其中的数字表示是条件概率。黄色的箭头表示的是一种greedy search，概率并不是最大的，如果把beam width设置为2，那么后续可以找到绿色路径的结果，这个结果是最大的\n",
    "\n",
    "<center>\n",
    "    <img style=\"border-radius: 0.3125em;\n",
    "    box-shadow: 0 2px 4px 0 rgba(34,36,38,.12),0 2px 10px 0 rgba(34,36,38,.08);\" \n",
    "    src=\"https://img-blog.csdnimg.cn/2021032016512997.png\">\n",
    "    <br>\n",
    "    <div style=\"color:orange; border-bottom: 1px solid #d9d9d9;\n",
    "    display: inline-block;\n",
    "    color: #999;\n",
    "    padding: 2px;\">图2. Beam Search 搜索示意图</div>\n",
    "</center>\n",
    "\n",
    "具体Beam Search算法详解，可以参考文献[5]。\n",
    "\n",
    "上述我们提到，Beam Search算法在文本生成中用得比较多，用于选择较优的结果，但可能并不是最优的。为了优化该方案，Google为传统Beam Search算法增加了**长度归一化**(length normalization) 和 **覆盖惩罚**(coverage penalty)。\n",
    "\n",
    "在采用对数似然作为得分函数时，Beam Search 通常会倾向于更短的序列。因为对数似然是负数，越长的序列在计算 score 时得分越低 (加的负数越多)。在得分函数中引入 length normalization 对长度进行归一化可以解决这一问题。coverage penalty 主要用于使用 Attention 的场合，通过 coverage penalty 可以让 Decoder 均匀地关注于输入序列 x 的每一个 token，防止一些 token 获得过多的 Attention。把对数似然、length normalization 和 coverage penalty 结合在一起，可以得到新的得分函数，如下面的公式所示：\n",
    "\n",
    "$$\n",
    "score(x,y) = log(P(y|x))/lp(y) + cp(x,y) \\\\\n",
    "lp(y) = \\frac{(5+|y|)^{\\alpha}}{(5+1)^{\\alpha}}\n",
    "\\\\\n",
    "cp(x, y) = \\beta * \\sum_{i=1}^{|x|}log(min(\\sum_{j=1}^{|y|}a_{ij}))\n",
    "$$\n",
    "\n",
    "\n",
    "其中$a_{ij}$表示第j个target word在第i个source word上的Attention值， $lp$表示长度归一化，$cp$表示覆盖惩罚。\n",
    "\n",
    "有关Beam Search的具体内容可以参考：引用\n",
    "\n",
    "##### 3.4.2.1 LengthPenalty\n",
    "定义`LengthPenalty`类来实现Google提出的长度惩罚操作，方便后续使用集束搜索算法。\n",
    "\n",
    "这里首先使用`ExpandDims()`对输入的`input_tensor` 在给定的轴上添加额外维度，然后使用`Tile()`方法来按照给定的次数复制输入Tensor到指定`beam_width`的光速宽度，然后使用`Reshape()`操作重新划分向量形状。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LengthPenalty(nn.Cell):\n",
    "    \"\"\"\n",
    "    Length penalty.\n",
    "\n",
    "    Args:\n",
    "        weight (float): The length penalty weight.\n",
    "        compute_type (mstype): Mindspore data type. Default: mstype.float32.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, weight=1.0, compute_type=mstype.float32):\n",
    "        super(LengthPenalty, self).__init__()\n",
    "        self.weight = weight\n",
    "        self.add = P.Add()\n",
    "        self.pow = P.Pow()\n",
    "        self.div = P.RealDiv()\n",
    "        self.five = Tensor(5.0, mstype.float32)\n",
    "        self.six = Tensor(6.0, mstype.float32)\n",
    "        self.cast = P.Cast()\n",
    "\n",
    "    def construct(self, length_tensor):\n",
    "        \"\"\"\n",
    "        Process source sentence\n",
    "\n",
    "        Inputs:\n",
    "            length_tensor (Tensor):  the input tensor.\n",
    "\n",
    "        Returns:\n",
    "            Tensor, after punishment of length.\n",
    "        \"\"\"\n",
    "        length_tensor = self.cast(length_tensor, mstype.float32)\n",
    "        output = self.add(length_tensor, self.five)\n",
    "        output = self.div(output, self.six)\n",
    "        output = self.pow(output, self.weight)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "##### 3.4.2.1 TileBeam\n",
    "定义`TileBeam`类来实现集束平铺操作，方便后续使用集束搜索算法。\n",
    "\n",
    "这里首先使用`ExpandDims()`对输入的`input_tensor` 在给定的轴上添加额外维度，然后使用`Tile()`方法来按照给定的次数复制输入Tensor到指定`beam_width`的光速宽度，然后使用`Reshape()`操作重新划分向量形状。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TileBeam(nn.Cell):\n",
    "    \"\"\"\n",
    "    集束平铺操作.\n",
    "\n",
    "    Args:\n",
    "        beam_width (int): The Number of beam.\n",
    "        compute_type (mstype): Mindspore data type. Default: mstype.float32.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, beam_width, compute_type=mstype.float32):\n",
    "        super(TileBeam, self).__init__()\n",
    "        self.beam_width = beam_width\n",
    "\n",
    "        self.expand = P.ExpandDims()\n",
    "        self.tile = P.Tile()\n",
    "        self.reshape = P.Reshape()\n",
    "        self.shape = P.Shape()\n",
    "\n",
    "    def construct(self, input_tensor):\n",
    "        \"\"\"\n",
    "        处理源句子\n",
    "\n",
    "        Inputs:\n",
    "            input_tensor (Tensor):  with shape (N, T, D).\n",
    "\n",
    "        Returns:\n",
    "            Tensor, tiled tensor.\n",
    "        \"\"\"\n",
    "        shape = self.shape(input_tensor)\n",
    "        # 新增一个维度\n",
    "        input_tensor = self.expand(input_tensor, 1)\n",
    "        # 获取tile的shape: [1, beam, ...]\n",
    "        tile_shape = (1,) + (self.beam_width,)\n",
    "        for _ in range(len(shape) - 1):\n",
    "            # 将tile_shape逐次扩展\n",
    "            tile_shape = tile_shape + (1,)\n",
    "        # 将input_tensor平铺到tile_shape上\n",
    "        output = self.tile(input_tensor, tile_shape)\n",
    "        # reshape to [batch*beam, ...]\n",
    "        out_shape = (shape[0] * self.beam_width,) + shape[1:]\n",
    "        output = self.reshape(output, out_shape)\n",
    "\n",
    "        return output\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 3.4.2.2 SaturateCast\n",
    "定义`SaturateCast`类，该类是一个工具类，主要是用来将一种数据类型安全的转换成另一种数据类型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SaturateCast(nn.Cell):\n",
    "    \"\"\"构建安全类型转换器\"\"\"\n",
    "\n",
    "    def __init__(self, dst_type=mstype.float32):\n",
    "        super(SaturateCast, self).__init__()\n",
    "        self.cast = P.Cast()\n",
    "        self.dst_type = dst_type\n",
    "\n",
    "    def construct(self, x):\n",
    "        return self.cast(x, self.dst_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 3.4.2.3 PredLogProbs\n",
    "定义`PredLogProbs`类，该类利用Log Softmax激活函数按元素计算，输入经Softmax函数、Log函数转换后得到预测的概率取值，值的范围在[-inf,0)。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PredLogProbs(nn.Cell):\n",
    "    \"\"\"\n",
    "    Get log probs.\n",
    "\n",
    "    Args:\n",
    "        batch_size (int): Batch size of input dataset.\n",
    "        seq_length (int): The length of sequences.\n",
    "        width (int): Number of parameters of a layer\n",
    "        compute_type (int): Type of input type.\n",
    "        dtype (int): Type of MindSpore output type.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 batch_size,\n",
    "                 seq_length,\n",
    "                 width,\n",
    "                 compute_type=mstype.float32,\n",
    "                 dtype=mstype.float32):\n",
    "        super(PredLogProbs, self).__init__()\n",
    "        self.batch_size = batch_size\n",
    "        self.seq_length = seq_length\n",
    "        self.width = width\n",
    "        self.compute_type = compute_type\n",
    "        self.dtype = dtype\n",
    "        self.log_softmax = nn.LogSoftmax(axis=-1)\n",
    "        self.cast = P.Cast()\n",
    "\n",
    "    def construct(self, logits):\n",
    "        \"\"\"\n",
    "        计算log_softmax.\n",
    "\n",
    "        Inputs:\n",
    "            input_tensor (Tensor): A batch of sentences with shape (N, T).\n",
    "            output_weights (Tensor): A batch of masks with shape (N, T).\n",
    "\n",
    "        Returns:\n",
    "            Tensor, the prediction probability with shape (N, T').\n",
    "        \"\"\"\n",
    "        log_probs = self.log_softmax(logits)\n",
    "        return log_probs\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 3.4.2.4 BeamDecoderStep\n",
    "定义`BeamDecoderStep`类，该类定义了Beam Search的decoder过程。\n",
    "\n",
    "该类在Beam Decoder的每一个时间步利用上述的GNMT Decoder方法，将输入的词嵌入序列`input_embedding`，和上述的GNMTEncoder所得到的`encoder_states`及`encoder_attention_mask`放入GNMT Decoder中得到 `decoder_output`, `all_decoder_state`,`scores`三个参数，并利用上述定义的`PredLogProbs`计算投影面得到相似概率。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BeamDecoderStep(nn.Cell):\n",
    "    \"\"\"\n",
    "    多层 transformer的单步解码器.\n",
    "\n",
    "    Args:\n",
    "        config: The config of Transformer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 config,\n",
    "                 use_one_hot_embeddings,\n",
    "                 compute_type=mstype.float32):\n",
    "        super(BeamDecoderStep, self).__init__(auto_prefix=True)\n",
    "\n",
    "        self.vocab_size = config.vocab_size\n",
    "        self.word_embed_dim = config.hidden_size\n",
    "        self.embedding_lookup = EmbeddingLookup(\n",
    "            is_training=False,\n",
    "            vocab_size=config.vocab_size,\n",
    "            embed_dim=self.word_embed_dim,\n",
    "            use_one_hot_embeddings=use_one_hot_embeddings)\n",
    "\n",
    "        self.projection = PredLogProbs(\n",
    "            batch_size=config.batch_size * config.beam_width,\n",
    "            seq_length=1,\n",
    "            width=config.vocab_size,\n",
    "            compute_type=mstype.float16)\n",
    "\n",
    "        self.seq_length = config.max_decode_length\n",
    "        self.decoder = GNMTDecoder(config,\n",
    "                                   is_training=False,\n",
    "                                   infer_beam_width=config.beam_width)\n",
    "\n",
    "        self.ones_like = P.OnesLike()\n",
    "        self.shape = P.Shape()\n",
    "\n",
    "        self.create_att_paddings_from_input_paddings = CreateAttentionPaddingsFromInputPaddings(config,\n",
    "                                                                                                is_training=False)\n",
    "        self.expand = P.ExpandDims()\n",
    "        self.multiply = P.Mul()\n",
    "\n",
    "        ones = np.ones(shape=(config.max_decode_length, config.max_decode_length))\n",
    "        self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32)\n",
    "\n",
    "        self.cast_compute_type = SaturateCast(dst_type=compute_type)\n",
    "\n",
    "        self.transpose = P.Transpose()\n",
    "        self.transpose_orders = (1, 0, 2)\n",
    "\n",
    "    def construct(self, input_ids, enc_states, enc_attention_mask, decoder_hidden_state=None):\n",
    "        \"\"\"\n",
    "        获取 log 计算的概率.\n",
    "\n",
    "        Args:\n",
    "            input_ids: [batch_size * beam_width, m]\n",
    "            enc_states: [batch_size * beam_width, T, D]\n",
    "            enc_attention_mask: [batch_size * beam_width, T]\n",
    "            decoder_hidden_state: [decoder_layers_nums, 2, batch_size * beam_width, hidden_size].\n",
    "\n",
    "        Returns:\n",
    "            Tensor, the log_probs. [batch_size * beam_width, 1, vocabulary_size]\n",
    "        \"\"\"\n",
    "\n",
    "        # 处理词嵌入过程. input_embedding: [batch_size * beam_width, m, D], embedding_tables: [V, D]\n",
    "        input_embedding, _ = self.embedding_lookup(input_ids)\n",
    "        input_embedding = self.cast_compute_type(input_embedding)\n",
    "\n",
    "        input_shape = self.shape(input_ids)\n",
    "        input_len = input_shape[1]\n",
    "        # [m, batch_size * beam_width, D]\n",
    "        input_embedding = self.transpose(input_embedding, self.transpose_orders)\n",
    "        enc_states = self.transpose(enc_states, self.transpose_orders)\n",
    "\n",
    "        # decoder_output: [m, batch_size*beam_width, V], scores:[b, t_q, t_k], all_decoder_state:[4,2,b*beam_width,D]\n",
    "        decoder_output, all_decoder_state, scores = self.decoder(input_embedding, enc_states, enc_attention_mask, decoder_hidden_state)\n",
    "        # [batch_size * beam_width, m, v]\n",
    "        decoder_output = self.transpose(decoder_output, self.transpose_orders)\n",
    "\n",
    "        # take the last step, [batch_size * beam_width, 1, V]\n",
    "        decoder_output = decoder_output[:, (input_len - 1):input_len, :]\n",
    "\n",
    "        # 投影并计算log的概率\n",
    "        log_probs = self.projection(decoder_output)\n",
    "\n",
    "        # [batch_size * beam_width, 1, vocabulary_size]\n",
    "        return log_probs, all_decoder_state, scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.4.6 BeamSearchDecoder\n",
    "\n",
    "定义Beam Search Decoder类执行Beam Search 算法，该类使用到了上述定义的`BeamDecoderStep`，同时在该类中定义了`one_step`函数，执行Beam Search中的每一个时间步。\n",
    "\n",
    "\n",
    "`one_step`函数能够执行beam search解码步骤，该步骤使用 cell 来计算概率，然后执行beam search步骤以计算得分并选择候选标记ID。\n",
    "\n",
    "与上述所述的BeamSearch原理一致，我们这里将定义的GNMT Decoder计算出来的文本词向量定义为需要找到最相似的文本序列，即可能性最大的文本序列。\n",
    "\n",
    "对于每个源句（样本）从 ids 中根据其对应的 scores 选择当前时间步 top-K （K 是 beam_size）的候选词，加入输入用于特殊处理到达结束的翻译候选。\n",
    "\n",
    "首先将源词嵌入、encoder当前时间步的输出、encoder的attenton_mask输入网络进行运算。\n",
    "\n",
    "然后将输入句子通过`embedding_lookup`进行词嵌入， 随后将词嵌入和encoder输出、encoder的attention_mask共同输入到上述定义的GNMT Decoder类中进行解码，最后计算解码完毕的词向量概率。\n",
    "\n",
    "在`construct`函数中，我们规定了Beam Search中每一步需要执行的操作，首先执行`one_step`获取当前时间步的分数和状态，然后通过添加长度归一化和覆盖惩罚以便尽可能准确匹配。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BeamSearchDecoder(nn.Cell):\n",
    "    \"\"\"\n",
    "    集束搜索的decoder.\n",
    "\n",
    "    Args:\n",
    "        batch_size (int): Batch size of input dataset.\n",
    "        seq_length (int): Length of input sequence.\n",
    "        vocab_size (int): The shape of each embedding vector.\n",
    "        decoder (Cell): The GNMT decoder.\n",
    "        beam_width (int): Beam width for beam search in inferring. 默认值: 4.\n",
    "        decoder_layers_nums (int): The nums of decoder layers.\n",
    "        length_penalty_weight (float): Penalty for sentence length. 默认值: 0.6.\n",
    "        max_decode_length (int): Max decode length for inferring. 默认值: 64.\n",
    "        sos_id (int): The index of start label <SOS>. 默认值: 1.\n",
    "        eos_id (int): The index of end label <EOS>. 默认值: 2.\n",
    "        compute_type (:class:`mindspore.dtype`): Compute type. 默认值: mstype.float32.\n",
    "\n",
    "    Returns:\n",
    "        Tensor, predictions output.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 batch_size,\n",
    "                 seq_length,\n",
    "                 vocab_size,\n",
    "                 decoder,\n",
    "                 beam_width=4,\n",
    "                 decoder_layers_nums=4,\n",
    "                 length_penalty_weight=0.6,\n",
    "                 cov_penalty_factor=0.1,\n",
    "                 hidden_size=1024,\n",
    "                 max_decode_length=64,\n",
    "                 sos_id=2,\n",
    "                 eos_id=3,\n",
    "                 is_using_while=True,\n",
    "                 compute_type=mstype.float32):\n",
    "        super(BeamSearchDecoder, self).__init__()\n",
    "\n",
    "        self.encoder_length = seq_length\n",
    "        self.hidden_size = hidden_size\n",
    "        self.batch_size = batch_size\n",
    "        self.vocab_size = vocab_size\n",
    "        self.beam_width = beam_width\n",
    "        self.decoder_layers_nums = decoder_layers_nums\n",
    "        self.length_penalty_weight = length_penalty_weight\n",
    "        self.cov_penalty_factor = cov_penalty_factor\n",
    "        self.max_decode_length = max_decode_length\n",
    "        self.decoder = decoder\n",
    "        self.is_using_while = is_using_while\n",
    "\n",
    "        self.add = P.Add()\n",
    "        self.expand = P.ExpandDims()\n",
    "        self.reshape = P.Reshape()\n",
    "        self.shape_flat = (-1,)\n",
    "        self.shape = P.Shape()\n",
    "\n",
    "        self.zero_tensor = Tensor(np.zeros([batch_size, beam_width]), mstype.float32)\n",
    "        self.ninf_tensor = Tensor(np.full([batch_size, beam_width], -INF), mstype.float32)\n",
    "\n",
    "        self.select = P.Select()\n",
    "        self.flat_shape = (batch_size, beam_width * vocab_size)\n",
    "        self.topk = P.TopK(sorted=True)\n",
    "        self.vocab_size_tensor = Tensor(self.vocab_size, mstype.int32)\n",
    "        self.real_div = P.RealDiv()\n",
    "        self.equal = P.Equal()\n",
    "        self.eos_ids = Tensor(np.full([batch_size, beam_width], eos_id), mstype.int32)\n",
    "\n",
    "        beam_ids = np.tile(np.arange(beam_width).reshape((1, beam_width)), [batch_size, 1])\n",
    "        self.beam_ids = Tensor(beam_ids, mstype.int32)\n",
    "\n",
    "        batch_ids = np.arange(batch_size * beam_width).reshape((batch_size, beam_width)) // beam_width\n",
    "        self.batch_ids = Tensor(batch_ids, mstype.int32)\n",
    "\n",
    "        self.concat = P.Concat(axis=-1)\n",
    "        self.gather_nd = P.GatherNd()\n",
    "\n",
    "        self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)\n",
    "        if self.is_using_while:\n",
    "            self.start = Tensor(0, dtype=mstype.int32)\n",
    "            self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length + 1], sos_id),\n",
    "                                   mstype.int32)\n",
    "        else:\n",
    "            self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)\n",
    "\n",
    "        init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1])\n",
    "        self.init_scores = Tensor(init_scores, mstype.float32)\n",
    "        self.init_finished = Tensor(np.zeros([batch_size, beam_width], dtype=bool))\n",
    "        self.init_length = Tensor(np.zeros([batch_size, beam_width], dtype=np.int32))\n",
    "\n",
    "        self.length_penalty = LengthPenalty(weight=length_penalty_weight)\n",
    "\n",
    "        self.one = Tensor(1, mstype.int32)\n",
    "        self.prob_concat = P.Concat(axis=1)\n",
    "        self.cast = P.Cast()\n",
    "        self.decoder_hidden_state = Tensor(np.zeros([self.decoder_layers_nums, 2,\n",
    "                                                     self.batch_size * self.beam_width,\n",
    "                                                     hidden_size]), mstype.float32)\n",
    "\n",
    "        self.zeros_scores = Tensor(np.zeros([batch_size, beam_width], dtype=np.float32))\n",
    "        self.active_index = Tensor(np.ones([batch_size, beam_width], dtype=np.int32))\n",
    "        self.init_zeros = Tensor(np.zeros([batch_size, beam_width], dtype=np.int32))\n",
    "        self.init_ones = Tensor(np.ones([batch_size, beam_width], dtype=np.float32))\n",
    "\n",
    "        self.accu_attn_scores = Tensor(np.zeros([batch_size, beam_width, self.encoder_length], dtype=np.float32))\n",
    "\n",
    "        self.zeros = Tensor([0], mstype.int32)\n",
    "        self.eos_tensor = Tensor(np.full([batch_size, beam_width, beam_width], eos_id), mstype.int32)\n",
    "\n",
    "        self.ones_3d = Tensor(np.full([batch_size, beam_width, self.encoder_length], 1), mstype.float32)\n",
    "        self.neg_inf_3d = Tensor(np.full([batch_size, beam_width, self.encoder_length], -INF), mstype.float32)\n",
    "        self.zeros_3d = Tensor(np.full([batch_size, beam_width, self.encoder_length], 0), mstype.float32)\n",
    "        self.zeros_2d = Tensor(np.full([batch_size * beam_width, self.encoder_length], 0), mstype.int32)\n",
    "        self.argmin = P.ArgMinWithValue(axis=1)\n",
    "        self.reducesum = P.ReduceSum()\n",
    "        self.div = P.Div()\n",
    "        self.shape_op = P.Shape()\n",
    "        self.mul = P.Mul()\n",
    "        self.log = P.Log()\n",
    "        self.less = P.Less()\n",
    "        self.tile = P.Tile()\n",
    "        self.noteq = P.Neg()\n",
    "        self.zeroslike = P.ZerosLike()\n",
    "        self.greater_equal = P.GreaterEqual()\n",
    "        self.sub = P.Sub()\n",
    "\n",
    "    def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs,\n",
    "                 state_seq, state_length, idx=None, decoder_hidden_state=None, accu_attn_scores=None,\n",
    "                 state_finished=None):\n",
    "        \"\"\"\n",
    "        Beam search的单步输出.\n",
    "\n",
    "        Inputs:\n",
    "            cur_input_ids (Tensor):  with shape (batch_size * beam_width, 1).\n",
    "            enc_states (Tensor):  with shape (batch_size * beam_width, T, D).\n",
    "            enc_attention_mask (Tensor):  with shape (batch_size * beam_width, T).\n",
    "            state_log_probs (Tensor):  with shape (batch_size, beam_width).\n",
    "            state_seq (Tensor):  with shape (batch_size, beam_width, m).\n",
    "            state_length (Tensor):  with shape (batch_size, beam_width).\n",
    "            idx (Tensor):  with shape ().\n",
    "            decoder_hidden_state (Tensor): with shape (decoder_layer_num, 2, batch_size * beam_width, D).\n",
    "            accu_attn_scores (Tensor): with shape (batchsize, beam_width, seq_length).\n",
    "            state_finished (Tensor):  with shape (batch_size, beam_width).\n",
    "        \"\"\"\n",
    "\n",
    "        # log_probs, [batch_size * beam_width, 1, V]\n",
    "        log_probs, all_decoder_state, attn = self.decoder(cur_input_ids, enc_states, enc_attention_mask,\n",
    "                                                          decoder_hidden_state)\n",
    "        # consider attention_scores\n",
    "        attn = self.reshape(attn, (-1, self.beam_width, self.encoder_length))\n",
    "        state_finished_attn = self.cast(state_finished, mstype.int32)\n",
    "        attn_mask_0 = self.tile(self.expand(state_finished_attn, 2), (1, 1, self.encoder_length))\n",
    "        attn_mask_0 = self.cast(attn_mask_0, mstype.bool_)\n",
    "        attn_new = self.select(attn_mask_0, self.zeros_3d, attn)\n",
    "        accu_attn_scores = self.add(accu_attn_scores, attn_new)\n",
    "\n",
    "        # log_probs: [batch_size, beam_width, V]\n",
    "        log_probs = self.reshape(log_probs, (-1, self.beam_width, self.vocab_size))\n",
    "        # select topk indices, [batch_size, beam_width, V]\n",
    "        total_log_probs = self.add(log_probs, self.expand(state_log_probs, -1))\n",
    "        # mask finished beams, [batch_size, beam_width]\n",
    "        # t-1 has finished\n",
    "        mask_tensor = self.select(state_finished, self.ninf_tensor, self.zero_tensor)\n",
    "        # save the t-1 probability\n",
    "        total_log_probs = self.add(total_log_probs, self.expand(mask_tensor, -1))\n",
    "        # [batch, beam*vocab]\n",
    "        flat_scores = self.reshape(total_log_probs, (-1, self.beam_width * self.vocab_size))\n",
    "        # 选择top-k, [batch, beam]\n",
    "        topk_scores, topk_indices = self.topk(flat_scores, self.beam_width)\n",
    "\n",
    "        temp = topk_indices\n",
    "        beam_indices = self.zeroslike(topk_indices)\n",
    "        for _ in range(self.beam_width - 1):\n",
    "            temp = self.sub(temp, self.vocab_size_tensor)\n",
    "            res = self.cast(self.greater_equal(temp, 0), mstype.int32)\n",
    "            beam_indices = beam_indices + res\n",
    "        word_indices = topk_indices - beam_indices * self.vocab_size_tensor\n",
    "        # ======================================================================\n",
    "\n",
    "        # mask finished indices, [batch, beam]\n",
    "        beam_indices = self.select(state_finished, self.beam_ids, beam_indices)\n",
    "        word_indices = self.select(state_finished, self.eos_ids, word_indices)\n",
    "        topk_scores = self.select(state_finished, state_log_probs, topk_scores)\n",
    "\n",
    "        # sort according to scores with -inf for finished beams, [batch, beam]\n",
    "        # t ends\n",
    "        tmp_log_probs = self.select(\n",
    "            self.equal(word_indices, self.eos_ids),\n",
    "            self.ninf_tensor,\n",
    "            topk_scores)\n",
    "\n",
    "        _, tmp_indices = self.topk(tmp_log_probs, self.beam_width)\n",
    "        # 更新, [batch_size, beam_width, 2]\n",
    "        tmp_gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(tmp_indices, -1)))\n",
    "        # [batch_size, beam_width]\n",
    "        beam_indices = self.gather_nd(beam_indices, tmp_gather_indices)\n",
    "        word_indices = self.gather_nd(word_indices, tmp_gather_indices)\n",
    "        topk_scores = self.gather_nd(topk_scores, tmp_gather_indices)\n",
    "\n",
    "        # 用于选择活动光束的聚集索引\n",
    "        gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(beam_indices, -1)))\n",
    "\n",
    "        # 如果在前一步中没有完成，长度加1, [batch_size, beam_width]\n",
    "        length_add = self.add(state_length, self.one)\n",
    "        state_length = self.select(state_finished, state_length, length_add)\n",
    "        state_length = self.gather_nd(state_length, gather_indices)\n",
    "        # 拼接 seq\n",
    "        seq = self.gather_nd(state_seq, gather_indices)\n",
    "        # 更新 accu_attn_scores\n",
    "        accu_attn_scores = self.gather_nd(accu_attn_scores, gather_indices)\n",
    "        # 更新 all_decoder_state\n",
    "        all_decoder_state = self.reshape(all_decoder_state,\n",
    "                                         (self.decoder_layers_nums * 2, self.batch_size, self.beam_width,\n",
    "                                          self.hidden_size))\n",
    "        for i in range(self.decoder_layers_nums * 2):\n",
    "            all_decoder_state[i, :, :, :] = self.gather_nd(all_decoder_state[i, :, :, :], gather_indices)\n",
    "        all_decoder_state = self.reshape(all_decoder_state,\n",
    "                                         (self.decoder_layers_nums, 2, self.batch_size * self.beam_width,\n",
    "                                          self.hidden_size))\n",
    "\n",
    "        # 更新 state_seq\n",
    "        if self.is_using_while:\n",
    "            state_seq_new = self.cast(seq, mstype.float32)\n",
    "            word_indices_fp32 = self.cast(word_indices, mstype.float32)\n",
    "            state_seq_new[:, :, idx] = word_indices_fp32\n",
    "            state_seq = self.cast(state_seq_new, mstype.int32)\n",
    "        else:\n",
    "            state_seq = self.concat((seq, self.expand(word_indices, -1)))\n",
    "\n",
    "        cur_input_ids = self.reshape(word_indices, (-1, 1))\n",
    "        state_log_probs = topk_scores\n",
    "        state_finished = self.equal(word_indices, self.eos_ids)\n",
    "\n",
    "        return cur_input_ids, state_log_probs, state_seq, state_length, \\\n",
    "               all_decoder_state, accu_attn_scores, state_finished\n",
    "\n",
    "    def construct(self, enc_states, enc_attention_mask):\n",
    "        \"\"\"\n",
    "        处理源句子\n",
    "\n",
    "        Inputs:\n",
    "            enc_states (Tensor): Output of transformer encoder with shape (batch_size * beam_width, T, D).\n",
    "            enc_attention_mask (Tensor): encoder attention mask with shape (batch_size * beam_width, T).\n",
    "\n",
    "        Returns:\n",
    "            Tensor, predictions output.\n",
    "        \"\"\"\n",
    "        # 开始beam search 算法\n",
    "        cur_input_ids = self.start_ids\n",
    "        state_log_probs = self.init_scores\n",
    "        state_seq = self.init_seq\n",
    "        state_finished = self.init_finished\n",
    "        state_length = self.init_length\n",
    "        decoder_hidden_state = self.decoder_hidden_state\n",
    "        accu_attn_scores = self.accu_attn_scores\n",
    "\n",
    "        if not self.is_using_while:\n",
    "            for _ in range(self.max_decode_length):\n",
    "                cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \\\n",
    "                state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,\n",
    "                                               state_seq, state_length, None, decoder_hidden_state, accu_attn_scores,\n",
    "                                               state_finished)\n",
    "        else:\n",
    "            # At present, only ascend910 supports while operation.\n",
    "            idx = self.start + 1\n",
    "            ends = self.start + self.max_decode_length + 1\n",
    "            while idx < ends:\n",
    "                cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \\\n",
    "                state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,\n",
    "                                               state_seq, state_length, idx, decoder_hidden_state, accu_attn_scores,\n",
    "                                               state_finished)\n",
    "                idx = idx + 1\n",
    "\n",
    "        # 添加长度惩罚系数\n",
    "        penalty_len = self.length_penalty(state_length)\n",
    "        log_probs = self.real_div(state_log_probs, penalty_len)\n",
    "        penalty_cov = C.clip_by_value(accu_attn_scores, 0.0, 1.0)\n",
    "        penalty_cov = self.log(penalty_cov)\n",
    "        penalty_less = self.less(penalty_cov, self.neg_inf_3d)\n",
    "        penalty = self.select(penalty_less, self.zeros_3d, penalty_cov)\n",
    "        penalty = self.reducesum(penalty, 2)\n",
    "        log_probs = log_probs + penalty * self.cov_penalty_factor\n",
    "        # 根据惩罚系数排列\n",
    "        _, top_beam_indices = self.topk(log_probs, self.beam_width)\n",
    "        gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))\n",
    "        # 按照序列顺序和Attention分数排列\n",
    "        predicted_ids = self.gather_nd(state_seq, gather_indices)\n",
    "        if not self.is_using_while:\n",
    "            predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)]\n",
    "        else:\n",
    "            predicted_ids = predicted_ids[:, 0:1, :self.max_decode_length]\n",
    "\n",
    "        return predicted_ids\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.4.7 GNMT v2\n",
    "\n",
    "GNMTv2模型主要由Encoder、Decoder和Attention机制组成，其中Encoder和Decoder使用共享的词嵌入向量。Encoder：由四个长短期记忆 （LSTM） 层组成。第一个 LSTM 层是双向的，而其他三层是单向的。Decoder：由四个单向 LSTM 层和一个完全连接的分类器组成。LSTM 的输出嵌入尺寸为 1024。Attention：采用标准化的Badanau Attention机制。首先，将解码器的第一层输出用作注意力机制的输入。然后，将注意力机制的计算结果连接到解码器LSTM的输入端，作为后续LSTM层的输入。如下图3 所示\n",
    "\n",
    "<center>\n",
    "    <img style=\"border-radius: 0.3125em;\n",
    "    box-shadow: 0 2px 4px 0 rgba(34,36,38,.12),0 2px 10px 0 rgba(34,36,38,.08); width:50%; height:50%\" \n",
    "    src=\"https://github.com/NVIDIA/DeepLearningExamples/raw/master/PyTorch/Translation/GNMT/img/diagram.png\">\n",
    "    <br>\n",
    "    <div style=\"color:orange; border-bottom: 1px solid #d9d9d9;\n",
    "    display: inline-block;\n",
    "    color: #999;\n",
    "    padding: 2px;\">图3. GNMT v2结构图</div>\n",
    "</center>\n",
    "\n",
    "依据上图原理，我们实现GNMT v2模型如下:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "class GNMT(nn.Cell):\n",
    "    \"\"\"\n",
    "    GNMT with encoder and decoder.\n",
    "\n",
    "    In GNMT, we define T = src_max_len, T' = tgt_max_len.\n",
    "\n",
    "    Args:\n",
    "        config: Model config.\n",
    "        is_training (bool): Whether is training.\n",
    "        use_one_hot_embeddings (bool): Whether use one-hot embedding.\n",
    "\n",
    "    Returns:\n",
    "        Tuple[Tensor], network outputs.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 config,\n",
    "                 is_training: bool = False,\n",
    "                 use_one_hot_embeddings: bool = False,\n",
    "                 use_positional_embedding: bool = True,\n",
    "                 compute_type=mstype.float32):\n",
    "        super(GNMT, self).__init__()\n",
    "\n",
    "        self.input_mask_from_dataset = config.input_mask_from_dataset\n",
    "        self.max_positions = config.seq_length\n",
    "        self.attn_embed_dim = config.hidden_size\n",
    "\n",
    "        config = copy.deepcopy(config)\n",
    "        self.is_training = is_training\n",
    "        self.num_layers = config.num_hidden_layers\n",
    "        self.hidden_dropout_prob = config.hidden_dropout_prob\n",
    "        self.vocab_size = config.vocab_size\n",
    "        self.seq_length = config.seq_length\n",
    "        self.batch_size = config.batch_size\n",
    "        self.max_decode_length = config.max_decode_length\n",
    "        self.word_embed_dim = config.hidden_size\n",
    "\n",
    "        self.beam_width = config.beam_width\n",
    "\n",
    "        self.transpose = P.Transpose()\n",
    "        self.transpose_orders = (1, 0, 2)\n",
    "        self.embedding_lookup = EmbeddingLookup(\n",
    "            is_training=self.is_training,\n",
    "            vocab_size=self.vocab_size,\n",
    "            embed_dim=self.word_embed_dim,\n",
    "            use_one_hot_embeddings=use_one_hot_embeddings)\n",
    "\n",
    "        self.gnmt_encoder = GNMTEncoder(config, is_training)\n",
    "\n",
    "        if self.is_training:\n",
    "            # use for train.\n",
    "            self.gnmt_decoder = GNMTDecoder(config, is_training)\n",
    "        else:\n",
    "            # use for infer.\n",
    "            self.expand = P.ExpandDims()\n",
    "            self.multiply = P.Mul()\n",
    "            self.reshape = P.Reshape()\n",
    "            self.create_att_paddings_from_input_paddings = CreateAttentionPaddingsFromInputPaddings(config,\n",
    "                                                                                                    is_training=False)\n",
    "            self.tile_beam = TileBeam(beam_width=config.beam_width)\n",
    "            self.cast_compute_type = SaturateCast(dst_type=compute_type)\n",
    "\n",
    "            beam_decoder_cell = BeamDecoderStep(config, use_one_hot_embeddings=use_one_hot_embeddings)\n",
    "            # link beam_search after decoder\n",
    "            self.beam_decoder = BeamSearchDecoder(\n",
    "                batch_size=config.batch_size,\n",
    "                seq_length=config.seq_length,\n",
    "                vocab_size=config.vocab_size,\n",
    "                decoder=beam_decoder_cell,\n",
    "                beam_width=config.beam_width,\n",
    "                decoder_layers_nums=config.num_hidden_layers,\n",
    "                length_penalty_weight=config.length_penalty_weight,\n",
    "                hidden_size=config.hidden_size,\n",
    "                max_decode_length=config.max_decode_length)\n",
    "            self.beam_decoder.add_flags(loop_can_unroll=True)\n",
    "        self.shape = P.Shape()\n",
    "\n",
    "    def construct(self, source_ids, source_mask=None, target_ids=None):\n",
    "        \"\"\"\n",
    "        Construct network.\n",
    "\n",
    "        In this method, T = src_max_len, T' = tgt_max_len.\n",
    "\n",
    "        Args:\n",
    "            source_ids (Tensor): Source sentences with shape (N, T).\n",
    "            source_mask (Tensor): Source sentences padding mask with shape (N, T),\n",
    "                where 0 indicates padding position.\n",
    "            target_ids (Tensor): Target sentences with shape (N, T').\n",
    "\n",
    "        Returns:\n",
    "            Tuple[Tensor], network outputs.\n",
    "        \"\"\"\n",
    "\n",
    "        # Process source sentences. src_embeddings:[N, T, D].\n",
    "        src_embeddings, _ = self.embedding_lookup(source_ids)\n",
    "        # T, N, D\n",
    "        inputs = self.transpose(src_embeddings, self.transpose_orders)\n",
    "        # encoder. encoder_outputs: [T, N, D]\n",
    "        encoder_outputs = self.gnmt_encoder(inputs)\n",
    "\n",
    "        # decoder.\n",
    "        if self.is_training:\n",
    "            # training\n",
    "            # process target input sentences. N, T, D\n",
    "            tgt_embeddings, _ = self.embedding_lookup(target_ids)\n",
    "            # T, N, D\n",
    "            tgt_embeddings = self.transpose(tgt_embeddings, self.transpose_orders)\n",
    "            # cell: [T,N,D].\n",
    "            cell, _, _ = self.gnmt_decoder(tgt_embeddings,\n",
    "                                           encoder_outputs,\n",
    "                                           attention_mask=source_mask)\n",
    "            # decoder_output: (N, T', V).\n",
    "            decoder_outputs = self.transpose(cell, self.transpose_orders)\n",
    "            out = decoder_outputs\n",
    "        else:\n",
    "            # infer\n",
    "            # encoder_output:  [T, N, D] -> [N, T, D].\n",
    "            beam_encoder_output = self.transpose(encoder_outputs, self.transpose_orders)\n",
    "            # bean search for encoder output, [N*beam_width, T, D]\n",
    "            beam_encoder_output = self.tile_beam(beam_encoder_output)\n",
    "\n",
    "            # (N*beam_width, T)\n",
    "            beam_enc_attention_pad = self.tile_beam(source_mask)\n",
    "\n",
    "            predicted_ids = self.beam_decoder(beam_encoder_output, beam_enc_attention_pad)\n",
    "            predicted_ids = self.reshape(predicted_ids, (-1, self.max_decode_length))\n",
    "            out = predicted_ids\n",
    "\n",
    "        return out\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.5 GNMTTraining\n",
    "为了方便后续模型训练，我们定义`GNMTTraining`类。\n",
    "\n",
    "该类将上述定义的`GNMT`模型和`PredLogProbs`模型封装，能够对`GNMT`模型进行训练并利用`PredLogProbs`类返回预测结果概率。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GNMTTraining(nn.Cell):\n",
    "    \"\"\"\n",
    "    GNMT training network.\n",
    "\n",
    "    Args:\n",
    "        config: The config of GNMT.\n",
    "        is_training (bool): Specifies whether to use the training mode.\n",
    "        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings.\n",
    "\n",
    "    Returns:\n",
    "        Tensor, prediction_scores.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config, is_training, use_one_hot_embeddings):\n",
    "        super(GNMTTraining, self).__init__()\n",
    "        self.gnmt = GNMT(config, is_training, use_one_hot_embeddings)\n",
    "        self.projection = PredLogProbs(\n",
    "            batch_size=config.batch_size * config.beam_width,\n",
    "            seq_length=1,\n",
    "            width=config.vocab_size,\n",
    "        )\n",
    "\n",
    "    def construct(self, source_ids, source_mask, target_ids):\n",
    "        \"\"\"\n",
    "        Construct network.\n",
    "\n",
    "        Args:\n",
    "            source_ids (Tensor): Source sentence.\n",
    "            source_mask (Tensor): Source padding mask.\n",
    "            target_ids (Tensor): Target sentence.\n",
    "\n",
    "        Returns:\n",
    "            Tensor, prediction_scores.\n",
    "        \"\"\"\n",
    "        decoder_outputs = self.gnmt(source_ids, source_mask, target_ids)\n",
    "        prediction_scores = self.projection(decoder_outputs)\n",
    "        return prediction_scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. 损失函数与优化器\n",
    "### 4.1 LabelSmoothedCrossEntropyCriterion\n",
    "\n",
    "在将深度学习模型用于分类任务时，我们通常会遇到以下问题：过拟合和过度自信。对过拟合的研究非常深入，可以通过Early Stop， Droppout，Weight Decacy等方法解决。另一方面，我们缺乏解决过度自信的工具。Label Smooth 是解决这两个问题的正则化技术。通过对 label 进行 weighted sum，能够取得比 one hot label 更好的效果。\n",
    "\n",
    "label smoothing 将 label 由 $y_k$ 转化为 $y_k^{LS}$，公式为：\n",
    "$$\n",
    "y_k^{LS}=y_k(1-\\alpha)+\\frac{\\alpha}{K}\n",
    "$$\n",
    "\n",
    "\n",
    "标签平滑用y_hot和均匀分布的混合替换一键编码的标签矢量y_hot：\n",
    "\n",
    "$$\n",
    "y_{LS} =（1- α）* y_hot + α / K\n",
    "$$\n",
    "其中K是标签类别的数量，而α是确定平滑量的超参数。如果α = 0，我们获得原始的一热编码y_hot。如果α = 1，我们得到均匀分布。\n",
    "\n",
    "当损失函数是交叉熵时，使用标签平滑，模型将softmax函数应用于倒数第二层的对数向量z，以计算其输出概率p。在这种设置下，交叉熵损失函数相对于对数的梯度很简单\n",
    "\n",
    "$$\n",
    "∇CE= p - y = softmax（z）-y\n",
    "$$\n",
    "\n",
    "其中y是标签分布。特别是，我们可以看到\n",
    "\n",
    "- 梯度下降将尝试使p尽可能接近y。\n",
    "- 渐变范围介于-1和1。\n",
    "一键编码的标签鼓励将最大的logit间隙输入到softmax函数中。直观地，大的logit间隙与有限梯度相结合将使模型的适应性降低，并且对其预测过于自信。\n",
    "\n",
    "定义本文的损失函数`LabelSmoothedCrossEntropyCriterion`如下：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LabelSmoothedCrossEntropyCriterion(nn.Cell):\n",
    "    \"\"\"\n",
    "    Label Smoothed Cross-Entropy Criterion.\n",
    "\n",
    "    Args:\n",
    "        config: The config of GNMT.\n",
    "\n",
    "    Returns:\n",
    "        Tensor, final loss.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super(LabelSmoothedCrossEntropyCriterion, self).__init__()\n",
    "        self.vocab_size = config.vocab_size\n",
    "        self.batch_size = config.batch_size\n",
    "        self.smoothing = 0.1\n",
    "        self.confidence = 0.9\n",
    "        self.last_idx = (-1,)\n",
    "        self.reduce_sum = P.ReduceSum()\n",
    "        self.reduce_mean = P.ReduceMean()\n",
    "        self.reshape = P.Reshape()\n",
    "        self.neg = P.Neg()\n",
    "        self.cast = P.Cast()\n",
    "        self.index_ids = Tensor(np.arange(config.batch_size * config.max_decode_length).reshape((-1, 1)), mstype.int32)\n",
    "        self.gather_nd = P.GatherNd()\n",
    "        self.expand = P.ExpandDims()\n",
    "        self.concat = P.Concat(axis=-1)\n",
    "\n",
    "    def construct(self, prediction_scores, label_ids, label_weights):\n",
    "        \"\"\"\n",
    "        Construct network to calculate loss.\n",
    "\n",
    "        Args:\n",
    "            prediction_scores (Tensor): Prediction scores. [batchsize, seq_len, vocab_size]\n",
    "            label_ids (Tensor): Labels. [batchsize, seq_len]\n",
    "            label_weights (Tensor): Mask tensor. [batchsize, seq_len]\n",
    "\n",
    "        Returns:\n",
    "            Tensor, final loss.\n",
    "        \"\"\"\n",
    "        prediction_scores = self.reshape(prediction_scores, (-1, self.vocab_size))\n",
    "        label_ids = self.reshape(label_ids, (-1, 1))\n",
    "        label_weights = self.reshape(label_weights, (-1,))\n",
    "        tmp_gather_indices = self.concat((self.index_ids, label_ids))\n",
    "        nll_loss = self.neg(self.gather_nd(prediction_scores, tmp_gather_indices))\n",
    "        nll_loss = label_weights * nll_loss\n",
    "        smooth_loss = self.neg(self.reduce_mean(prediction_scores, self.last_idx))\n",
    "        smooth_loss = label_weights * smooth_loss\n",
    "        loss = self.reduce_sum(self.confidence * nll_loss + self.smoothing * smooth_loss, ())\n",
    "        loss = loss / self.batch_size\n",
    "        return loss\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 优化器\n",
    "\n",
    "为了快速优化损失函数，本文定义Adam(Adaptive Moment Estimation，自适应矩估计)类，完成Adam的操作功能。\n",
    "\n",
    "Adam能够利用梯度的一阶矩估计和二阶矩估计动态调整每个参数的学习率。\n",
    "其参数更新的计算公式如下：\n",
    "\n",
    "$$\n",
    "t=t+1 \\\\\n",
    "\n",
    "moment\\_1\\_{out}=β1∗moment\\_1+(1−β_1)∗grad \\\\\n",
    "moment\\_2\\_out=β2∗moment\\_2+(1−β2)∗grad∗grad \\\\\n",
    "learning\\_rate=learning\\_rate∗\\frac{\\sqrt{1−β^t_2}}{1−β^t_1} \\\\\n",
    "param\\_out=param−learning\\_rate∗\\frac{moment\\_1}{\\sqrt{moment\\_2}+ϵ}\n",
    "\n",
    "$$\n",
    "\n",
    "相关论文参见：[引用]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from mindspore._checkparam import Validator as validator\n",
    "from mindspore.common.initializer import initializer\n",
    "from mindspore.nn import Optimizer\n",
    "from mindspore.ops import composite as C\n",
    "\n",
    "_learning_rate_update_func = ['linear', 'cos', 'sin']\n",
    "\n",
    "adam_opt = C.MultitypeFuncGraph(\"adam_opt\")\n",
    "\n",
    "#注册参数类型，自动更新\n",
    "@adam_opt.register(\"Tensor\", \"Tensor\", \"Tensor\", \"Tensor\", \"Tensor\", \"Tensor\", \"Tensor\", \"Tensor\", \"Tensor\", \"Bool\")\n",
    "def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):\n",
    "    \"\"\"\n",
    "    参数更新.\n",
    "\n",
    "    Args:\n",
    "        beta1 (Tensor): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).\n",
    "        beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).\n",
    "        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.\n",
    "        lr (Tensor): Learning rate.\n",
    "        weight_decay_tensor (Tensor): Weight decay. Should be equal to or greater than 0.\n",
    "        param (Tensor): Parameters.\n",
    "        m (Tensor): m value of parameters.\n",
    "        v (Tensor): v value of parameters.\n",
    "        gradient (Tensor): Gradient of parameters.\n",
    "\n",
    "    Returns:\n",
    "        Tensor, the new value of v after updating.\n",
    "    \"\"\"\n",
    "    op_mul = P.Mul()\n",
    "    op_square = P.Square()\n",
    "    op_sqrt = P.Sqrt()\n",
    "    op_cast = P.Cast()\n",
    "    op_reshape = P.Reshape()\n",
    "    op_shape = P.Shape()\n",
    "\n",
    "    param_fp32 = op_cast(param, mstype.float32)\n",
    "    m_fp32 = op_cast(m, mstype.float32)\n",
    "    v_fp32 = op_cast(v, mstype.float32)\n",
    "    gradient_fp32 = op_cast(gradient, mstype.float32)\n",
    "\n",
    "    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)\n",
    "\n",
    "    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)\n",
    "                                            - beta2, op_square(gradient_fp32))\n",
    "\n",
    "    update = next_m / (op_sqrt(next_v) + eps)\n",
    "    if decay_flag:\n",
    "        update = update + op_mul(weight_decay_tensor, param_fp32)\n",
    "\n",
    "    update_with_lr = op_mul(lr, update)\n",
    "    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))\n",
    "\n",
    "    next_v = F.depend(next_v, F.assign(param, next_param))\n",
    "    next_v = F.depend(next_v, F.assign(m, next_m))\n",
    "    next_v = F.depend(next_v, F.assign(v, next_v))\n",
    "    return next_v\n",
    "\n",
    "\n",
    "def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):\n",
    "    \"\"\"Check the type of inputs.\"\"\"\n",
    "    validator.check_value_type(\"beta1\", beta1, [float], prim_name)\n",
    "    validator.check_value_type(\"beta2\", beta2, [float], prim_name)\n",
    "    validator.check_value_type(\"eps\", eps, [float], prim_name)\n",
    "    validator.check_value_type(\"weight_dacay\", weight_decay, [float], prim_name)\n",
    "\n",
    "\n",
    "\n",
    "def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, prim_name):\n",
    "    \"\"\"Check the type of inputs.\"\"\"\n",
    "    validator.check_float_positive('learning_rate', learning_rate, prim_name)\n",
    "    validator.check_float_legal_value('learning_rate', learning_rate, prim_name)\n",
    "    validator.check_float_positive('end_learning_rate', end_learning_rate, prim_name)\n",
    "    validator.check_float_legal_value('end_learning_rate', end_learning_rate, prim_name)\n",
    "    validator.check_float_positive('power', power, prim_name)\n",
    "    validator.check_float_legal_value('power', power, prim_name)\n",
    "    validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)\n",
    "\n",
    "\n",
    "@adam_opt.register(\"Function\", \"Tensor\", \"Tensor\", \"Tensor\", \"Tensor\", \"Number\", \"Tensor\", \"Tensor\", \"Tensor\", \"Tensor\",\n",
    "                   \"Tensor\")\n",
    "def _run_opt_with_one_number(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, moment1,\n",
    "                             moment2):\n",
    "    \"\"\"Apply adam optimizer to the weight parameter using Tensor.\"\"\"\n",
    "    success = True\n",
    "    success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,\n",
    "                                    eps, gradient))\n",
    "    return success\n",
    "\n",
    "class Adam(Optimizer):\n",
    "    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,\n",
    "                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):\n",
    "        super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale)\n",
    "        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)\n",
    "        validator.check_value_type(\"use_locking\", use_locking, [bool], self.cls_name)\n",
    "        validator.check_value_type(\"use_nesterov\", use_nesterov, [bool], self.cls_name)\n",
    "        validator.check_value_type(\"loss_scale\", loss_scale, [float], self.cls_name)\n",
    "\n",
    "        self.beta1 = Tensor(beta1, mstype.float32)\n",
    "        self.beta2 = Tensor(beta2, mstype.float32)\n",
    "        self.beta1_power = Parameter(initializer(1, [1], mstype.float32))\n",
    "        self.beta2_power = Parameter(initializer(1, [1], mstype.float32))\n",
    "        self.eps = eps\n",
    "\n",
    "        self.moment1 = self.parameters.clone(prefix=\"moment1\", init='zeros')\n",
    "        self.moment2 = self.parameters.clone(prefix=\"moment2\", init='zeros')\n",
    "\n",
    "        self.hyper_map = C.HyperMap()\n",
    "        self.opt = P.Adam(use_locking, use_nesterov)\n",
    "\n",
    "        self.pow = P.Pow()\n",
    "        self.sqrt = P.Sqrt()\n",
    "        self.one = Tensor(np.array([1.0]).astype(np.float32))\n",
    "        self.realdiv = P.RealDiv()\n",
    "\n",
    "        self.lr_scalar = P.ScalarSummary()\n",
    "\n",
    "        self.exec_mode = context.get_context(\"mode\")\n",
    "\n",
    "    def construct(self, gradients):\n",
    "        \"\"\"Adam optimizer.\"\"\"\n",
    "        params = self.parameters\n",
    "        moment1 = self.moment1\n",
    "        moment2 = self.moment2\n",
    "        gradients = self.decay_weight(gradients)\n",
    "        gradients = self.scale_grad(gradients)\n",
    "        lr = self.get_lr()\n",
    "\n",
    "        #currently, Summary operators only support graph mode\n",
    "        if self.exec_mode == context.GRAPH_MODE:\n",
    "            self.lr_scalar(\"learning_rate\", lr)\n",
    "\n",
    "        beta1_power = self.beta1_power * self.beta1\n",
    "        self.beta1_power = beta1_power\n",
    "        beta2_power = self.beta2_power * self.beta2\n",
    "        self.beta2_power = beta2_power\n",
    "        if self.is_group_lr:\n",
    "            success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,\n",
    "                                               self.beta2, self.eps),\n",
    "                                     lr, gradients, params, moment1, moment2)\n",
    "        else:\n",
    "            success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,\n",
    "                                               self.beta2, self.eps, lr),\n",
    "                                     gradients, params, moment1, moment2)\n",
    "        return success\n",
    "\n",
    "def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):\n",
    "    \"\"\"Check the type of inputs.\"\"\"\n",
    "    validator.check_value_type(\"beta1\", beta1, [float], prim_name)\n",
    "    validator.check_value_type(\"beta2\", beta2, [float], prim_name)\n",
    "    validator.check_value_type(\"eps\", eps, [float], prim_name)\n",
    "    validator.check_value_type(\"weight_dacay\", weight_decay, [float], prim_name)\n",
    "\n",
    "def _get_optimizer(config, network, lr):\n",
    "    \"\"\"get gnmt optimizer, support Adam, Lamb, Momentum.\"\"\"\n",
    "    optimizer = Adam(network.trainable_params(), lr, beta1=0.9, beta2=0.98)\n",
    "    return optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.3 学习率\n",
    "模型初始训练时，模型的权重是随机初始化的，初始学习率太大，可能会导致模型的震荡（不稳定）。\n",
    "\n",
    "本文选择了学习率预热方法，在初始的几个epoch或者是steps内用一个比较小的学习率，这样可以保证网络能够具有良好的收敛性。\n",
    "\n",
    "但是较低的学习率会使得训练过程变得非常缓慢，因此采用较低学习率逐渐增大至较高学习率的方式实现网络训练的“热身”阶段，训练完制定的epoch或者steps之后恢复设置的初始学习率。\n",
    "\n",
    "为了防止从较小学习率到指定的初始学习率变化较大引起误差增大，利用gradual warmup集束缓解这个问题，通过每个steps逐渐增大lr，直到达到指定的初始学习率后再开始学习率的下降。\n",
    "逐渐预热的公式为：\n",
    "$$\n",
    "\\alpha^{\\prime}_t=\\frac{t}{T^{\\prime}}, 1\\leq t \\leq T^{\\prime}\n",
    "$$\n",
    "\n",
    "根据上述公式，实现`Warmup_MultiStepLR_scheduler`如下：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def Warmup_MultiStepLR_scheduler(base_lr=0.002, total_update_num=200, warmup_steps=200, remain_steps=1.0,\n",
    "                                 decay_interval=-1, decay_steps=4, decay_factor=0.5):\n",
    "    \"\"\"\n",
    "    Implements of polynomial decay learning rate scheduler which cycles by default.\n",
    "\n",
    "    Args:\n",
    "        base_lr (float): Initial learning rate.\n",
    "        total_update_num (int): Total update steps.\n",
    "        warmup_steps (int or float): Warmup steps.\n",
    "        remain_steps (int or float): start decay at 'remain_steps' iteration\n",
    "        decay_interval (int): interval between LR decay steps\n",
    "        decay_steps (int): Decay steps.\n",
    "        decay_factor (float): decay factor\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray, learning rate of each step.\n",
    "    \"\"\"\n",
    "\n",
    "    if decay_steps <= 0:\n",
    "        raise ValueError(\"`decay_steps` must larger than 1.\")\n",
    "    remain_steps = convert_float2int(remain_steps, total_update_num)\n",
    "    warmup_steps = convert_float2int(warmup_steps, total_update_num)\n",
    "    if warmup_steps > remain_steps:\n",
    "        warmup_steps = remain_steps\n",
    "\n",
    "    if decay_interval < 0:\n",
    "        decay_iterations = total_update_num - remain_steps\n",
    "        decay_interval = decay_iterations // decay_steps\n",
    "        decay_interval = max(decay_interval, 1)\n",
    "    else:\n",
    "        decay_interval = convert_float2int(decay_interval, total_update_num)\n",
    "\n",
    "    lrs = np.zeros(shape=total_update_num, dtype=np.float32)\n",
    "    _start_step = 0\n",
    "    for last_epoch in range(_start_step, total_update_num):\n",
    "        if last_epoch < warmup_steps:\n",
    "            if warmup_steps != 0:\n",
    "                warmup_factor = math.exp(math.log(0.01) / warmup_steps)\n",
    "            else:\n",
    "                warmup_factor = 1.0\n",
    "            inv_decay = warmup_factor ** (warmup_steps - last_epoch)\n",
    "            lrs[last_epoch] = base_lr * inv_decay\n",
    "        elif last_epoch >= remain_steps:\n",
    "            decay_iter = last_epoch - remain_steps\n",
    "            num_decay_step = decay_iter // decay_interval + 1\n",
    "            num_decay_step = min(num_decay_step, decay_steps)\n",
    "            lrs[last_epoch] = base_lr * (decay_factor ** num_decay_step)\n",
    "        else:\n",
    "            lrs[last_epoch] = base_lr\n",
    "    return lrs\n",
    "\n",
    "def convert_float2int(values, total_steps):\n",
    "    if isinstance(values, float):\n",
    "        values = int(values * total_steps)\n",
    "    return values\n",
    "\n",
    "def _get_lr(config, update_steps):\n",
    "    \"\"\"generate learning rate.\"\"\"\n",
    "    lr = Tensor(Warmup_MultiStepLR_scheduler(base_lr=config.lr,\n",
    "                                            total_update_num=update_steps,\n",
    "                                                 warmup_steps=config.warmup_steps,\n",
    "                                                 remain_steps=config.warmup_lr_remain_steps,\n",
    "                                                 decay_interval=config.warmup_lr_decay_interval,\n",
    "                                                 decay_steps=config.decay_steps,\n",
    "                                                 decay_factor=config.lr_scheduler_power), dtype=mstype.float32)\n",
    "    return lr"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. 模型训练与保存\n",
    "### 5.1 模型训练\n",
    "本文定义模型训练函数，该函数加载需要训练的模型、模型配置参数、模型训练过程中的回调函数和数据集。\n",
    "\n",
    "通过设置模式为预训练或微调，即可利用`MindSpore`中自带的`Model`类开始训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def _train(model, config,\n",
    "           pre_training_dataset=None, fine_tune_dataset=None, test_dataset=None,\n",
    "           callbacks: list = None):\n",
    "    \"\"\"\n",
    "    Train model.\n",
    "\n",
    "    Args:\n",
    "        model (Model): MindSpore model instance.\n",
    "        config: Config of mass model.\n",
    "        pre_training_dataset (Dataset): Pre-training dataset.\n",
    "        fine_tune_dataset (Dataset): Fine-tune dataset.\n",
    "        test_dataset (Dataset): Test dataset.\n",
    "        callbacks (list): A list of callbacks.\n",
    "    \"\"\"\n",
    "    callbacks = callbacks if callbacks else []\n",
    "\n",
    "    if pre_training_dataset is not None:\n",
    "        print(\" | Start pre-training job.\")\n",
    "        epoch_size = pre_training_dataset.get_repeat_count()\n",
    "        print(\"epoch size \", epoch_size)\n",
    "        model.train(config.epochs, pre_training_dataset,\n",
    "                    callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode)\n",
    "\n",
    "    if fine_tune_dataset is not None:\n",
    "        print(\" | Start fine-tuning job.\")\n",
    "        epoch_size = fine_tune_dataset.get_repeat_count()\n",
    "\n",
    "        model.train(config.epochs, fine_tune_dataset,\n",
    "                    callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 GNMTNetworkWithLoss\n",
    "定义`GNMTNetworkWithLoss`类，该类用于调整模型训练过程。\n",
    "\n",
    "为了更好的实现模型训练，我们这里将`GNMTTraining`训练对象和`LabelSmoothedCrossEntropyCriterion`损失函数对象封装在一个对象中，可以方便的计算模型训练过程中损失。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GNMTNetworkWithLoss(nn.Cell):\n",
    "    \"\"\"\n",
    "    用于为训练网络提供GNMT训练损失。\n",
    "\n",
    "    Args:\n",
    "        config (BertConfig): The config of GNMT.\n",
    "        is_training (bool): Specifies whether to use the training mode.\n",
    "        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.\n",
    "\n",
    "    Returns:\n",
    "        Tensor, the loss of the network.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config, is_training, use_one_hot_embeddings=False):\n",
    "        super(GNMTNetworkWithLoss, self).__init__()\n",
    "        self.gnmt = GNMTTraining(config, is_training, use_one_hot_embeddings)\n",
    "        self.loss = LabelSmoothedCrossEntropyCriterion(config)\n",
    "        self.cast = P.Cast()\n",
    "\n",
    "    def construct(self,\n",
    "                  source_ids,\n",
    "                  source_mask,\n",
    "                  target_ids,\n",
    "                  label_ids,\n",
    "                  label_weights):\n",
    "        prediction_scores = self.gnmt(source_ids, source_mask, target_ids)\n",
    "        total_loss = self.loss(prediction_scores, label_ids, label_weights)\n",
    "        return self.cast(total_loss, mstype.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.3 GNMTTrainOneStepWithLossScaleCell\n",
    "定义`GNMTTrainOneStepWithLossScaleCell`类型继承于`TrainOneStepWithLossScaleCell`。\n",
    "\n",
    "该类是可以使用混合精度功能的训练网络，它使用网络、优化器和用于更新损失缩放系数（loss scale）的Cell(或一个Tensor)作为参数。\n",
    "\n",
    "可在host侧或device侧更新损失缩放系数。 如果需要在host侧更新，使用Tensor作为 scale_sense ，否则，使用可更新损失缩放系数的Cell实例作为 scale_sense。\n",
    "\n",
    "这里我们实现了对于GNMT模型的包含损失缩放（loss scale）的单次训练类，在实例化该对象之后，可以将优化器附加到训练网络函数来创建反向图。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore.nn.wrap.grad_reducer import DistributedGradReducer\n",
    "from mindspore.communication.management import get_group_size\n",
    "from mindspore.ops import functional as F\n",
    "\n",
    "from mindspore.common import dtype as mstype\n",
    "from mindspore.nn.wrap.grad_reducer import DistributedGradReducer\n",
    "from mindspore.communication.management import get_group_size\n",
    "\n",
    "GRADIENT_CLIP_TYPE = 1\n",
    "GRADIENT_CLIP_VALUE = 5.0\n",
    "\n",
    "clip_grad = C.MultitypeFuncGraph(\"clip_grad\")\n",
    "\n",
    "grad_scale = C.MultitypeFuncGraph(\"grad_scale\")\n",
    "reciprocal = P.Reciprocal()\n",
    "\n",
    "\n",
    "@clip_grad.register(\"Number\", \"Number\", \"Tensor\")\n",
    "def _clip_grad(clip_type, clip_value, grad):\n",
    "    \"\"\"\n",
    "    Clip gradients.\n",
    "\n",
    "    Inputs:\n",
    "        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.\n",
    "        clip_value (float): Specifies how much to clip.\n",
    "        grad (tuple[Tensor]): Gradients.\n",
    "\n",
    "    Outputs:\n",
    "        tuple[Tensor], clipped gradients.\n",
    "    \"\"\"\n",
    "    if clip_type not in (0, 1):\n",
    "        return grad\n",
    "    dt = F.dtype(grad)\n",
    "    if clip_type == 0:\n",
    "        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),\n",
    "                                   F.cast(F.tuple_to_array((clip_value,)), dt))\n",
    "    else:\n",
    "        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))\n",
    "    return new_grad\n",
    "\n",
    "\n",
    "@grad_scale.register(\"Tensor\", \"Tensor\")\n",
    "def tensor_grad_scale(scale, grad):\n",
    "    return grad * F.cast(reciprocal(scale), F.dtype(grad))\n",
    "\n",
    "\n",
    "_grad_overflow = C.MultitypeFuncGraph(\"_grad_overflow\")\n",
    "grad_overflow = P.FloatStatus()\n",
    "\n",
    "\n",
    "@_grad_overflow.register(\"Tensor\")\n",
    "def _tensor_grad_overflow(grad):\n",
    "    return grad_overflow(grad)\n",
    "\n",
    "\n",
    "class GNMTTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):\n",
    "    \"\"\"\n",
    "    Encapsulation class of GNMT network training.\n",
    "\n",
    "    Append an optimizer to the training network after that the construct\n",
    "    function can be called to create the backward graph.\n",
    "\n",
    "    Args:\n",
    "        network: Cell. The training network. Note that loss function should have\n",
    "            been added.\n",
    "        optimizer: Optimizer. Optimizer for updating the weights.\n",
    "\n",
    "    Returns:\n",
    "        Tuple[Tensor, Tensor, Tensor], loss, overflow, sen.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, network, optimizer, scale_update_cell=None):\n",
    "\n",
    "        super(GNMTTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)\n",
    "        self.cast = P.Cast()\n",
    "        self.degree = 1\n",
    "        if self.reducer_flag:\n",
    "            self.degree = get_group_size()\n",
    "            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)\n",
    "\n",
    "        self.loss_scale = None\n",
    "        self.loss_scaling_manager = scale_update_cell\n",
    "        if scale_update_cell:\n",
    "            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))\n",
    "\n",
    "    def construct(self,\n",
    "                  source_eos_ids,\n",
    "                  source_eos_mask,\n",
    "                  target_sos_ids,\n",
    "                  target_eos_ids,\n",
    "                  target_eos_mask,\n",
    "                  sens=None):\n",
    "        \"\"\"\n",
    "        Construct network.\n",
    "\n",
    "        Args:\n",
    "            source_eos_ids (Tensor): Source sentence.\n",
    "            source_eos_mask (Tensor): Source padding mask.\n",
    "            target_sos_ids (Tensor): Target sentence.\n",
    "            target_eos_ids (Tensor): Prediction sentence.\n",
    "            target_eos_mask (Tensor): Prediction padding mask.\n",
    "            sens (Tensor): Loss sen.\n",
    "\n",
    "        Returns:\n",
    "            Tuple[Tensor, Tensor, Tensor], loss, overflow, sen.\n",
    "        \"\"\"\n",
    "        source_ids = source_eos_ids\n",
    "        source_mask = source_eos_mask\n",
    "        target_ids = target_sos_ids\n",
    "        label_ids = target_eos_ids\n",
    "        label_weights = target_eos_mask\n",
    "\n",
    "        weights = self.weights\n",
    "        loss = self.network(source_ids,\n",
    "                            source_mask,\n",
    "                            target_ids,\n",
    "                            label_ids,\n",
    "                            label_weights)\n",
    "        if sens is None:\n",
    "            scaling_sens = self.loss_scale\n",
    "        else:\n",
    "            scaling_sens = sens\n",
    "\n",
    "        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)\n",
    "\n",
    "        grads = self.grad(self.network, weights)(source_ids,\n",
    "                                                 source_mask,\n",
    "                                                 target_ids,\n",
    "                                                 label_ids,\n",
    "                                                 label_weights,\n",
    "                                                 self.cast(scaling_sens,\n",
    "                                                           mstype.float32))\n",
    "        # apply grad reducer on grads\n",
    "        grads = self.grad_reducer(grads)\n",
    "        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)\n",
    "        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)\n",
    "\n",
    "        cond = self.get_overflow_status(status, grads)\n",
    "        overflow = cond\n",
    "        if sens is None:\n",
    "            overflow = self.loss_scaling_manager(self.loss_scale, cond)\n",
    "        if not overflow:\n",
    "            self.optimizer(grads)\n",
    "        return (loss, cond, scaling_sens)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.4 损失函数回调接口\n",
    "定义`LossCallBack`类在每一轮训练完成时，执行相应操作，方便查看训练情况。\n",
    "\n",
    "我们这里将训练耗费的时间、当前轮次、当前step、当前模型的loss等信息输出到日志文件中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from mindspore.train.callback import Callback\n",
    "\n",
    "class LossCallBack(Callback):\n",
    "    \"\"\"\n",
    "    Monitor the loss in training.\n",
    "\n",
    "    If the loss is NAN or INF terminating training.\n",
    "\n",
    "    Note:\n",
    "        If per_print_times is 0 do not print loss.\n",
    "\n",
    "    Args:\n",
    "        per_print_times (int): Print loss every times. Default: 1.\n",
    "    \"\"\"\n",
    "    time_stamp_init = False\n",
    "    time_stamp_first = 0\n",
    "\n",
    "    def __init__(self, config, per_print_times: int = 1):\n",
    "        super(LossCallBack, self).__init__()\n",
    "        if not isinstance(per_print_times, int) or per_print_times < 0:\n",
    "            raise ValueError(\"print_step must be int and >= 0.\")\n",
    "        self.config = config\n",
    "        self._per_print_times = per_print_times\n",
    "\n",
    "        if not self.time_stamp_init:\n",
    "            self.time_stamp_first = self._get_ms_timestamp()\n",
    "            self.time_stamp_init = True\n",
    "\n",
    "    def step_end(self, run_context):\n",
    "        \"\"\"step end.\"\"\"\n",
    "        cb_params = run_context.original_args()\n",
    "        file_name = \"./loss.log\"\n",
    "        with open(file_name, \"a+\") as f:\n",
    "            time_stamp_current = self._get_ms_timestamp()\n",
    "            f.write(\"time: {}, epoch: {}, step: {}, outputs: [loss: {}, overflow: {}, loss scale value: {} ].\\n\".format(\n",
    "                time_stamp_current - self.time_stamp_first,\n",
    "                cb_params.cur_epoch_num,\n",
    "                cb_params.cur_step_num,\n",
    "                str(cb_params.net_outputs[0].asnumpy()),\n",
    "                str(cb_params.net_outputs[1].asnumpy()),\n",
    "                str(cb_params.net_outputs[2].asnumpy())\n",
    "            ))\n",
    "\n",
    "    @staticmethod\n",
    "    def _get_ms_timestamp():\n",
    "        t = time.time()\n",
    "        return int(round(t * 1000))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.5 模型训练保存配置\n",
    "定义模型训练的配置方法，用于设置模型必要的参数。\n",
    "\n",
    "首先设定模型加载`config`文件中的各项参数，初始化`GNMTNetworkWithLoss`参数，将这个带有损失函数的模型加载到`GNMTTrainOneStepWithLossScaleCell`模型训练`Cell`中。\n",
    "\n",
    "随后将该模型训练`Cell`送入`Model`类中作为一个整体模型类。\n",
    "\n",
    "最后为其配置`CheckPoint`保存点和每一个`step`需要得一些回调函数即完成全部模型装配，调用`_train`方法即可开始训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore.train.model import Model\n",
    "from mindspore.train.loss_scale_manager import DynamicLossScaleManager\n",
    "from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, SummaryCollector, TimeMonitor\n",
    "\n",
    "def _build_training_pipeline(config,\n",
    "                             pre_training_dataset=None,\n",
    "                             fine_tune_dataset=None,\n",
    "                             test_dataset=None):\n",
    "    \"\"\"\n",
    "    Build training pipeline.\n",
    "\n",
    "    Args:\n",
    "        config: Config of mass model.\n",
    "        pre_training_dataset (Dataset): Pre-training dataset.\n",
    "        fine_tune_dataset (Dataset): Fine-tune dataset.\n",
    "        test_dataset (Dataset): Test dataset.\n",
    "    \"\"\"\n",
    "    net_with_loss = GNMTNetworkWithLoss(config, is_training=True, use_one_hot_embeddings=True)\n",
    "    net_with_loss.init_parameters_data()\n",
    "    # _load_checkpoint_to_net(config, net_with_loss)\n",
    "\n",
    "    dataset = pre_training_dataset if pre_training_dataset is not None \\\n",
    "        else fine_tune_dataset\n",
    "\n",
    "    if dataset is None:\n",
    "        raise ValueError(\"pre-training dataset or fine-tuning dataset must be provided one.\")\n",
    "\n",
    "    update_steps = config.epochs * dataset.get_dataset_size()\n",
    "\n",
    "    lr = _get_lr(config, update_steps)\n",
    "    optimizer = _get_optimizer(config, net_with_loss, lr)\n",
    "\n",
    "    # Dynamic loss scale.\n",
    "    scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,\n",
    "                                            scale_factor=config.loss_scale_factor,\n",
    "                                            scale_window=config.scale_window)\n",
    "    net_with_grads = GNMTTrainOneStepWithLossScaleCell(\n",
    "        network=net_with_loss, optimizer=optimizer,\n",
    "        scale_update_cell=scale_manager.get_update_cell()\n",
    "    )\n",
    "    net_with_grads.set_train(True)\n",
    "    model = Model(net_with_grads)\n",
    "    loss_monitor = LossCallBack(config)\n",
    "    dataset_size = dataset.get_dataset_size()\n",
    "    time_cb = TimeMonitor(data_size=dataset_size)\n",
    "    ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,\n",
    "                                   keep_checkpoint_max=config.keep_ckpt_max)\n",
    "\n",
    "    rank_size = os.getenv('RANK_SIZE')\n",
    "    callbacks = [time_cb, loss_monitor]\n",
    "\n",
    "    if rank_size is None or int(rank_size) == 1:\n",
    "        ckpt_callback = ModelCheckpoint(\n",
    "            prefix=config.ckpt_prefix,\n",
    "            directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(config.device_id)),\n",
    "            config=ckpt_config)\n",
    "        callbacks.append(ckpt_callback)\n",
    "        summary_callback = SummaryCollector(summary_dir=\"./summary\", collect_freq=50)\n",
    "        callbacks.append(summary_callback)\n",
    "\n",
    "    print(f\" | ALL SET, PREPARE TO TRAIN.\")\n",
    "    _train(model=model, config=config,\n",
    "           pre_training_dataset=pre_training_dataset,\n",
    "           fine_tune_dataset=fine_tune_dataset,\n",
    "           test_dataset=test_dataset,\n",
    "           callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.6 开始训练\n",
    "\n",
    "将配置参数文件和数据集加载进context，接着将数据集划分为训练数据集和测试数据集。\n",
    "\n",
    "准备好相关数据后即可开始启动训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "text_translation\n",
      " | Starting training on single device.\n",
      " | Loading ../data/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord.\n",
      " | Dataset size: 2534982.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(30680:140204289664512,MainProcess):2022-11-08-07:29:40.507.505 [mindspore/train/model.py:1077] For LossCallBack callback, {'step_end'} methods may not be supported in later version, Use methods prefixed with 'on_train' or 'on_eval' instead when using customized callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " | ALL SET, PREPARE TO TRAIN.\n",
      " | Start pre-training job.\n",
      "epoch size  1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.475.264 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:254] CalMemBlockAllocSize] Memory not enough: current free memory size[399441920] is smaller than required size[827392000].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.475.322 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:535] DumpDynamicMemPoolDebugInfo] Start dump dynamic memory pool debug info.\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.475.339 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:494] operator()] Common mem all mem_block info: counts[8].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.475.356 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:498] operator()]  MemBlock info: number[0] mem_buf_counts[2] base_address[0x7f7c98000000] block_size[2147483648].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.475.404 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:498] operator()]  MemBlock info: number[1] mem_buf_counts[13] base_address[0x7f7d18000000] block_size[1073741824].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.475.422 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:498] operator()]  MemBlock info: number[2] mem_buf_counts[3] base_address[0x7f7d58000000] block_size[1073741824].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.475.437 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:498] operator()]  MemBlock info: number[3] mem_buf_counts[2] base_address[0x7f7d98000000] block_size[1073741824].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.475.456 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:498] operator()]  MemBlock info: number[4] mem_buf_counts[16] base_address[0x7f7dd8000000] block_size[1073741824].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.475.533 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:498] operator()]  MemBlock info: number[5] mem_buf_counts[42] base_address[0x7f7fca000000] block_size[1073741824].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.475.563 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:498] operator()]  MemBlock info: number[6] mem_buf_counts[28] base_address[0x7f811a000000] block_size[1073741824].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.475.957 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:498] operator()]  MemBlock info: number[7] mem_buf_counts[206] base_address[0x7f816a000000] block_size[1073741824].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.476.002 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:515] operator()] Common mem all idle mem_buf info: counts[23].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.476.022 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:525] operator()] Common mem total allocated memory[9663676416], used memory[8243255296], idle memory[1420421120].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.476.035 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:494] operator()] Persistent mem all mem_block info: counts[1].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.476.391 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:498] operator()]  MemBlock info: number[0] mem_buf_counts[326] base_address[0x7f800a000000] block_size[1073741824].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.476.457 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:515] operator()] Persistent mem all idle mem_buf info: counts[1].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.476.471 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:525] operator()] Persistent mem total allocated memory[1073741824], used memory[1073739264], idle memory[2560].\n",
      "[WARNING] PRE_ACT(30680,7f82c4ffa700,python):2022-11-08-07:30:19.476.484 [mindspore/ccsrc/common/mem_reuse/mem_dynamic_allocator.cc:538] DumpDynamicMemPoolDebugInfo] Finish dump dynamic memory pool debug info.\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Device(id:0) memory isn't enough and alloc failed, kernel name: Default/Add-op1563, alloc size: 827392000B.\n\n----------------------------------------------------\n- C++ Call Stack: (For framework developers)\n----------------------------------------------------\nmindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc:628 Run\n",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_30680/4066724352.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     86\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0mdo_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/tmp/ipykernel_30680/4066724352.py\u001b[0m in \u001b[0;36mdo_train\u001b[0;34m()\u001b[0m\n\u001b[1;32m     82\u001b[0m                              \u001b[0mpre_training_dataset\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpre_train_dataset\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     83\u001b[0m                              \u001b[0mfine_tune_dataset\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfine_tune_dataset\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 84\u001b[0;31m                              test_dataset=test_dataset)\n\u001b[0m\u001b[1;32m     85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     86\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_30680/3318493948.py\u001b[0m in \u001b[0;36m_build_training_pipeline\u001b[0;34m(config, pre_training_dataset, fine_tune_dataset, test_dataset)\u001b[0m\n\u001b[1;32m     64\u001b[0m            \u001b[0mfine_tune_dataset\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfine_tune_dataset\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     65\u001b[0m            \u001b[0mtest_dataset\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtest_dataset\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 66\u001b[0;31m            callbacks=callbacks)\n\u001b[0m",
      "\u001b[0;32m/tmp/ipykernel_30680/3454950244.py\u001b[0m in \u001b[0;36m_train\u001b[0;34m(model, config, pre_training_dataset, fine_tune_dataset, test_dataset, callbacks)\u001b[0m\n\u001b[1;32m     20\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"epoch size \"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     21\u001b[0m         model.train(config.epochs, pre_training_dataset,\n\u001b[0;32m---> 22\u001b[0;31m                     callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode)\n\u001b[0m\u001b[1;32m     23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     24\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mfine_tune_dataset\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/python-3.7.5/lib/python3.7/site-packages/mindspore/train/model.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, epoch, train_dataset, callbacks, dataset_sink_mode, sink_size, initial_epoch)\u001b[0m\n\u001b[1;32m   1047\u001b[0m                     \u001b[0mdataset_sink_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdataset_sink_mode\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1048\u001b[0m                     \u001b[0msink_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msink_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1049\u001b[0;31m                     initial_epoch=initial_epoch)\n\u001b[0m\u001b[1;32m   1050\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1051\u001b[0m         \u001b[0;31m# When it's Parameter Server training and using MindRT,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/python-3.7.5/lib/python3.7/site-packages/mindspore/train/model.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m     96\u001b[0m                 \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     97\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 98\u001b[0;31m             \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     99\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/python-3.7.5/lib/python3.7/site-packages/mindspore/train/model.py\u001b[0m in \u001b[0;36m_train\u001b[0;34m(self, epoch, train_dataset, callbacks, dataset_sink_mode, sink_size, initial_epoch, valid_dataset, valid_frequency, valid_dataset_sink_mode)\u001b[0m\n\u001b[1;32m    621\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    622\u001b[0m                 self._train_dataset_sink_process(epoch, train_dataset, list_callback,\n\u001b[0;32m--> 623\u001b[0;31m                                                  cb_params, sink_size, initial_epoch, valid_infos)\n\u001b[0m\u001b[1;32m    624\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    625\u001b[0m     \u001b[0;34m@\u001b[0m\u001b[0mstaticmethod\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/python-3.7.5/lib/python3.7/site-packages/mindspore/train/model.py\u001b[0m in \u001b[0;36m_train_dataset_sink_process\u001b[0;34m(self, epoch, train_dataset, list_callback, cb_params, sink_size, initial_epoch, valid_infos)\u001b[0m\n\u001b[1;32m    699\u001b[0m                 \u001b[0mlist_callback\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_train_step_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    700\u001b[0m                 \u001b[0mtrain_network\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_network_mode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_network\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 701\u001b[0;31m                 \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_network\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    702\u001b[0m                 \u001b[0mcb_params\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnet_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    703\u001b[0m                 \u001b[0;31m# In disaster recovery scenarios, need not to execute callbacks if this step executes failed.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/python-3.7.5/lib/python3.7/site-packages/mindspore/nn/cell.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    576\u001b[0m                 logger.warning(f\"For 'Cell', it's not support hook function in graph mode. If you want to use hook \"\n\u001b[1;32m    577\u001b[0m                                f\"function, please use context.set_context to set pynative mode.\")\n\u001b[0;32m--> 578\u001b[0;31m             \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompile_and_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    579\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    580\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/python-3.7.5/lib/python3.7/site-packages/mindspore/nn/cell.py\u001b[0m in \u001b[0;36mcompile_and_run\u001b[0;34m(self, *inputs)\u001b[0m\n\u001b[1;32m    986\u001b[0m                 \u001b[0mparallel_inputs_run\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_inputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    987\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0m_cell_graph_executor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mparallel_inputs_run\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mphase\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 988\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0m_cell_graph_executor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mnew_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mphase\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    989\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    990\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mauto_parallel_compile_and_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/python-3.7.5/lib/python3.7/site-packages/mindspore/common/api.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, obj, phase, *args)\u001b[0m\n\u001b[1;32m   1200\u001b[0m            \u001b[0;34m(\u001b[0m\u001b[0m_is_role_pserver\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0m_enable_distributed_mindrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_is_role_sched\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1201\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1202\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mphase\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1203\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1204\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mhas_compiled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mphase\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'predict'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/python-3.7.5/lib/python3.7/site-packages/mindspore/common/api.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, obj, phase, *args)\u001b[0m\n\u001b[1;32m   1237\u001b[0m         \u001b[0mphase_real\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mphase\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m'.'\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcreate_time\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m'.'\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m'.'\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marguments_key\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1238\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhas_compiled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mphase_real\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1239\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_exec_pip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mphase\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mphase_real\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1240\u001b[0m         \u001b[0;32mraise\u001b[0m \u001b[0mKeyError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'{} graph is not exist.'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mphase_real\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1241\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/python-3.7.5/lib/python3.7/site-packages/mindspore/common/api.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*arg, **kwargs)\u001b[0m\n\u001b[1;32m     96\u001b[0m     \u001b[0;34m@\u001b[0m\u001b[0mwraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     97\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 98\u001b[0;31m         \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     99\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0m_convert_python_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/python-3.7.5/lib/python3.7/site-packages/mindspore/common/api.py\u001b[0m in \u001b[0;36m_exec_pip\u001b[0;34m(self, obj, phase, *args)\u001b[0m\n\u001b[1;32m   1219\u001b[0m         \u001b[0mfn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconstruct\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1220\u001b[0m         \u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__parse_method__\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1221\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_graph_executor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mphase\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1222\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1223\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mphase\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'predict'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Device(id:0) memory isn't enough and alloc failed, kernel name: Default/Add-op1563, alloc size: 827392000B.\n\n----------------------------------------------------\n- C++ Call Stack: (For framework developers)\n----------------------------------------------------\nmindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc:628 Run\n"
     ]
    }
   ],
   "source": [
    "from mindspore.common import set_seed\n",
    "from collections import namedtuple\n",
    "from mindspore import context\n",
    "\n",
    "train_config_values = {\n",
    "    'device_target': \"GPU\",\n",
    "    'device_id':0,\n",
    "\n",
    "    'random_seed': 50,\n",
    "    'epochs': 8,\n",
    "    'batch_size': 128,\n",
    "    'pre_train_dataset': [\"./data/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord\"],\n",
    "    'fine_tune_dataset': \"\",\n",
    "    'test_dataset':  \"\",\n",
    "    'valid_dataset': \"\",\n",
    "    'dataset_sink_mode': True,\n",
    "    'input_mask_from_dataset': False,\n",
    "    # model_config\n",
    "    'seq_length': 51,\n",
    "    'vocab_size': 32320,\n",
    "    'hidden_size': 1024,\n",
    "    'num_hidden_layers': 4,\n",
    "    'intermediate_size': 4096,\n",
    "    'hidden_dropout_prob': 0.2,\n",
    "    'attention_dropout_prob': 0.2,\n",
    "    'initializer_range': 0.1,\n",
    "    'label_smoothing': 0.1,\n",
    "    'beam_width': 2,\n",
    "    'length_penalty_weight': 0.6,\n",
    "    'max_decode_length': 50,\n",
    "\n",
    "    # loss_scale_config\n",
    "    'init_loss_scale': 65536,\n",
    "    'loss_scale_factor': 2,\n",
    "    'scale_window': 1000,\n",
    "\n",
    "    # learn_rate_config\n",
    "    'optimizer': \"adam\",\n",
    "    'lr': 0.002 ,# 2e-3\n",
    "    'lr_scheduler': \"WarmupMultiStepLR\",\n",
    "    'lr_scheduler_power': 0.5,\n",
    "    'warmup_lr_remain_steps': 0.666,\n",
    "    'warmup_lr_decay_interval': -1,\n",
    "    'decay_steps': 4,\n",
    "    'decay_start_step': -1,\n",
    "    'warmup_steps': 200,\n",
    "    'min_lr': 0.000001 ,#1e-6\n",
    "\n",
    "    # checkpoint_options\n",
    "    'existed_ckpt': \"\",\n",
    "    'save_ckpt_steps': 3452,\n",
    "    'keep_ckpt_max': 8,\n",
    "    'ckpt_prefix': \"gnmt\",\n",
    "    'ckpt_path': \"text_translation\"\n",
    "\n",
    "}\n",
    "\n",
    "train_config_keys = namedtuple('train_config', list(train_config_values.keys()))\n",
    "\n",
    "train_config = train_config_keys._make(train_config_values.values())\n",
    "print(train_config.ckpt_path)\n",
    "\n",
    "def do_train():\n",
    "    _rank_size = os.getenv('RANK_SIZE')\n",
    "    context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target=train_config.device_target,\n",
    "                        reserve_class_name_in_scope=True)\n",
    "\n",
    "    set_seed(train_config.random_seed)\n",
    "\n",
    "    print(\" | Starting training on single device.\")\n",
    "    pre_train_dataset = load_dataset(data_files=train_config.pre_train_dataset,\n",
    "                                     batch_size=train_config.batch_size,\n",
    "                                     sink_mode=train_config.dataset_sink_mode) if train_config.pre_train_dataset else None\n",
    "    fine_tune_dataset = load_dataset(data_files=train_config.fine_tune_dataset,\n",
    "                                     batch_size=train_config.batch_size,\n",
    "                                     sink_mode=train_config.dataset_sink_mode) if train_config.fine_tune_dataset else None\n",
    "    test_dataset = load_dataset(data_files=train_config.test_dataset,\n",
    "                                batch_size=train_config.batch_size,\n",
    "                                sink_mode=train_config.dataset_sink_mode) if train_config.test_dataset else None\n",
    "\n",
    "    _build_training_pipeline(config=train_config,\n",
    "                             pre_training_dataset=pre_train_dataset,\n",
    "                             fine_tune_dataset=fine_tune_dataset,\n",
    "                             test_dataset=test_dataset)\n",
    "\n",
    "\n",
    "do_train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. 模型加载与评估\n",
    "### 6.1 模型加载\n",
    "\n",
    "本文首先定义基于`MindSpore`的模型加载方法。\n",
    "如果存在之前训练得到的ckpt格式模型参数文件，那么就从该模型文件中加载权重，如果是Parameter类型，是模型训练的`weights`参数，需要调用data属性在通过`set_data()`方法直接载入。\n",
    "\n",
    "如果是Tensor类型，需要给先将参数转成numpy对应类型，然后在转换成配置文件中的`dtype`类型，在调用`set_data()`方法装配。\n",
    "\n",
    "如果是numpy的`ndarray`数组类型，也需要转换成配置文件中的`dtype`类型，在利用set_data()装配。\n",
    "\n",
    "如果不存在训练好的ckpt模型参数文件，对于gamma参数会产生一个全1矩阵，而其他参数则是通过`random.uniform`随机产生。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore.train.serialization import load_checkpoint\n",
    "\n",
    "def _load_checkpoint_to_net(config, network):\n",
    "    \"\"\"load parameters to network from checkpoint.\"\"\"\n",
    "    if config.existed_ckpt:\n",
    "        if config.existed_ckpt.endswith(\".npz\"):\n",
    "            weights = np.load(config.existed_ckpt)\n",
    "        else:\n",
    "            weights = load_checkpoint(config.existed_ckpt)\n",
    "        for param in network.trainable_params():\n",
    "            weights_name = param.name\n",
    "            if weights_name not in weights:\n",
    "                raise ValueError(f\"Param {weights_name} is not found in ckpt file.\")\n",
    "\n",
    "            if isinstance(weights[weights_name], Parameter):\n",
    "                param.set_data(weights[weights_name].data)\n",
    "            elif isinstance(weights[weights_name], Tensor):\n",
    "                param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))\n",
    "            elif isinstance(weights[weights_name], np.ndarray):\n",
    "                param.set_data(Tensor(weights[weights_name], config.dtype))\n",
    "            else:\n",
    "                param.set_data(weights[weights_name])\n",
    "    else:\n",
    "        for param in network.trainable_params():\n",
    "            name = param.name\n",
    "            value = param.data\n",
    "            if isinstance(value, Tensor):\n",
    "                if name.endswith(\".gamma\"):\n",
    "                    param.set_data(one_weight(value.asnumpy().shape))\n",
    "                elif name.endswith(\".beta\") or name.endswith(\".bias\"):\n",
    "                    if param.data.dtype == \"Float32\":\n",
    "                        param.set_data((weight_variable(value.asnumpy().shape).astype(np.float32)))\n",
    "                    elif param.data.dtype == \"Float16\":\n",
    "                        param.set_data((weight_variable(value.asnumpy().shape).astype(np.float16)))\n",
    "                else:\n",
    "                    if param.data.dtype == \"Float32\":\n",
    "                        param.set_data(Tensor(weight_variable(value.asnumpy().shape).astype(np.float32)))\n",
    "                    elif param.data.dtype == \"Float16\":\n",
    "                        param.set_data(Tensor(weight_variable(value.asnumpy().shape).astype(np.float16)))\n",
    "\n",
    "def weight_variable(shape):\n",
    "    \"\"\"\n",
    "    Generate weight var.\n",
    "\n",
    "    Args:\n",
    "        shape (tuple): Shape.\n",
    "\n",
    "    Returns:\n",
    "        Tensor, var.\n",
    "    \"\"\"\n",
    "    limit = 0.1\n",
    "    values = np.random.uniform(-limit, limit, shape)\n",
    "    return values\n",
    "\n",
    "\n",
    "def one_weight(shape):\n",
    "    \"\"\"\n",
    "    Generate weight with ones.\n",
    "\n",
    "    Args:\n",
    "        shape (tuple): Shape.\n",
    "\n",
    "    Returns:\n",
    "        Tensor, var.\n",
    "    \"\"\"\n",
    "    ones = np.ones(shape).astype(np.float32)\n",
    "    return Tensor(ones)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2 推理计算\n",
    "\n",
    "#### 6.2.1 推理权重加载\n",
    "\n",
    "模型推理权重与模型训练权重在参数上并不一致，因此我们单独定义加载推理权重方法`load_infer_weights()`方法。\n",
    "\n",
    "如果是numpy的权重文件，我们可以直接将其加载进模型参数；如果是ckpt的Tensor参数，我们需要先将其转换为numpy，在从numpy中加载配置文件中要求的Tensor向量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_infer_weights(config):\n",
    "    \"\"\"\n",
    "    Load weights from ckpt or npz.\n",
    "\n",
    "    Args:\n",
    "        config: Config.\n",
    "\n",
    "    Returns:\n",
    "        dict, weights.\n",
    "    \"\"\"\n",
    "    model_path = config.existed_ckpt\n",
    "    if model_path.endswith(\".npz\"):\n",
    "        ms_ckpt = np.load(model_path)\n",
    "        is_npz = True\n",
    "    else:\n",
    "        ms_ckpt = load_checkpoint(model_path)\n",
    "        is_npz = False\n",
    "    weights = {}\n",
    "    for param_name in ms_ckpt:\n",
    "        infer_name = param_name.replace(\"gnmt.gnmt.\", \"\")\n",
    "        if infer_name.startswith(\"embedding_lookup.\"):\n",
    "            if is_npz:\n",
    "                weights[infer_name] = ms_ckpt[param_name]\n",
    "            else:\n",
    "                weights[infer_name] = ms_ckpt[param_name].data.asnumpy()\n",
    "            infer_name = \"beam_decoder.decoder.\" + infer_name\n",
    "            if is_npz:\n",
    "                weights[infer_name] = ms_ckpt[param_name]\n",
    "            else:\n",
    "                weights[infer_name] = ms_ckpt[param_name].data.asnumpy()\n",
    "            continue\n",
    "        elif not infer_name.startswith(\"gnmt_encoder\"):\n",
    "            if infer_name.startswith(\"gnmt_decoder.\"):\n",
    "                infer_name = infer_name.replace(\"gnmt_decoder.\", \"decoder.\")\n",
    "            infer_name = \"beam_decoder.decoder.\" + infer_name\n",
    "\n",
    "        if is_npz:\n",
    "            weights[infer_name] = ms_ckpt[param_name]\n",
    "        else:\n",
    "            weights[infer_name] = ms_ckpt[param_name].data.asnumpy()\n",
    "    return weights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 6.2.2 GNMTInferCell推理类\n",
    "定义`GNMTInferCell`类，能该类利用生成的网络完成推理。\n",
    "\n",
    "将`GNMT`模型的推理封装进`GNMTInferCell`推理类，使用该类能够装配`Model`后直接开始推理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GNMTInferCell(nn.Cell):\n",
    "    \"\"\"\n",
    "    Encapsulation class of GNMT network infer.\n",
    "\n",
    "    Args:\n",
    "        network (nn.Cell): GNMT model.\n",
    "\n",
    "    Returns:\n",
    "        Tuple[Tensor, Tensor], predicted_ids and predicted_probs.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, network):\n",
    "        super(GNMTInferCell, self).__init__(auto_prefix=False)\n",
    "        self.network = network\n",
    "\n",
    "    def construct(self,\n",
    "                  source_ids,\n",
    "                  source_mask):\n",
    "        \"\"\"Defines the computation performed.\"\"\"\n",
    "\n",
    "        predicted_ids = self.network(source_ids,\n",
    "                                     source_mask)\n",
    "\n",
    "        return predicted_ids"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 6.2.3 gnmt_infer推理方法\n",
    "\n",
    "本文定义`gnmt_infer()`方法来配置模型推理所需要的各种参数。\n",
    "\n",
    "推理流程如下：\n",
    "\n",
    "1. 首先实例化GNMT模型，并将之前训练好保存的模型权重和偏置等参数加载到实例化模型中。\n",
    "2. 使用`dataset.create_dict_iterator()`从数据集中创建训练`batch`。值得注意的是，如果最后一轮`batch`划分的数据不够`batch_size`，这里会为其填充无意义的`source_ids_pad`和无意义的`source_mask_pad`。\n",
    "3. 模型推理完成后，将其组成{\"source\": '', \"prediction\": ''}的字典列表返回到`infer()`方法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def infer(config):\n",
    "    \"\"\"\n",
    "    GNMT infer api.\n",
    "\n",
    "    Args:\n",
    "        config: Config.\n",
    "\n",
    "    Returns:\n",
    "        list, result with\n",
    "    \"\"\"\n",
    "    eval_dataset = load_dataset(data_files=config.test_dataset,\n",
    "                                batch_size=config.batch_size,\n",
    "                                sink_mode=config.dataset_sink_mode,\n",
    "                                drop_remainder=False,\n",
    "                                is_translate=True,\n",
    "                                shuffle=False) if config.test_dataset else None\n",
    "    prediction = gnmt_infer(config, eval_dataset)\n",
    "    return prediction\n",
    "\n",
    "def gnmt_infer(config, dataset):\n",
    "    \"\"\"\n",
    "    Run infer with GNMT.\n",
    "\n",
    "    Args:\n",
    "        config: Config.\n",
    "        dataset (Dataset): Dataset.\n",
    "\n",
    "    Returns:\n",
    "        List[Dict], prediction, each example has 4 keys, \"source\",\n",
    "        \"target\", \"prediction\" and \"prediction_prob\".\n",
    "    \"\"\"\n",
    "    tfm_model = GNMT(config=config,\n",
    "                     is_training=False,\n",
    "                     use_one_hot_embeddings=False)\n",
    "\n",
    "    params = tfm_model.trainable_params()\n",
    "    weights = load_infer_weights(config)\n",
    "    for param in params:\n",
    "        value = param.data\n",
    "        weights_name = param.name\n",
    "        if weights_name not in weights:\n",
    "            raise ValueError(f\"{weights_name} is not found in weights.\")\n",
    "        if isinstance(value, Tensor):\n",
    "            if weights_name in weights:\n",
    "                assert weights_name in weights\n",
    "                if isinstance(weights[weights_name], Parameter):\n",
    "                    if param.data.dtype == \"Float32\":\n",
    "                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))\n",
    "                    elif param.data.dtype == \"Float16\":\n",
    "                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))\n",
    "\n",
    "                elif isinstance(weights[weights_name], Tensor):\n",
    "                    param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))\n",
    "                elif isinstance(weights[weights_name], np.ndarray):\n",
    "                    param.set_data(Tensor(weights[weights_name], config.dtype))\n",
    "                else:\n",
    "                    param.set_data(weights[weights_name])\n",
    "            else:\n",
    "                print(\"weight not found in checkpoint: \" + weights_name)\n",
    "                param.set_data(zero_weight(value.asnumpy().shape))\n",
    "\n",
    "    print(\" | Load weights successfully.\")\n",
    "    tfm_infer = GNMTInferCell(tfm_model)\n",
    "    model = Model(tfm_infer)\n",
    "\n",
    "    predictions = []\n",
    "    source_sentences = []\n",
    "\n",
    "    shape = P.Shape()\n",
    "    concat = P.Concat(axis=0)\n",
    "    batch_index = 1\n",
    "    pad_idx = 0\n",
    "    sos_idx = 2\n",
    "    eos_idx = 3\n",
    "    source_ids_pad = Tensor(np.tile(np.array([[sos_idx, eos_idx] + [pad_idx] * (config.seq_length - 2)]),\n",
    "                                    [config.batch_size, 1]), mstype.int32)\n",
    "    source_mask_pad = Tensor(np.tile(np.array([[1, 1] + [0] * (config.seq_length - 2)]),\n",
    "                                     [config.batch_size, 1]), mstype.int32)\n",
    "    for batch in dataset.create_dict_iterator():\n",
    "        source_sentences.append(batch[\"source_eos_ids\"].asnumpy())\n",
    "        source_ids = Tensor(batch[\"source_eos_ids\"], mstype.int32)\n",
    "        source_mask = Tensor(batch[\"source_eos_mask\"], mstype.int32)\n",
    "\n",
    "        active_num = shape(source_ids)[0]\n",
    "        if active_num < config.batch_size:\n",
    "            source_ids = concat((source_ids, source_ids_pad[active_num:, :]))\n",
    "            source_mask = concat((source_mask, source_mask_pad[active_num:, :]))\n",
    "\n",
    "        start_time = time.time()\n",
    "        predicted_ids = model.predict(source_ids, source_mask)\n",
    "\n",
    "        print(f\" | BatchIndex = {batch_index}, Batch size: {config.batch_size}, active_num={active_num}, \"\n",
    "              f\"Time cost: {time.time() - start_time}.\")\n",
    "        if active_num < config.batch_size:\n",
    "            predicted_ids = predicted_ids[:active_num, :]\n",
    "        batch_index = batch_index + 1\n",
    "        predictions.append(predicted_ids.asnumpy())\n",
    "\n",
    "    output = []\n",
    "    for inputs, batch_out in zip(source_sentences, predictions):\n",
    "        for i, _ in enumerate(batch_out):\n",
    "            if batch_out.ndim == 3:\n",
    "                batch_out = batch_out[:, 0]\n",
    "\n",
    "            example = {\n",
    "                \"source\": inputs[i].tolist(),\n",
    "                \"prediction\": batch_out[i].tolist()\n",
    "            }\n",
    "            output.append(example)\n",
    "\n",
    "    return output\n",
    "\n",
    "def zero_weight(shape):\n",
    "    \"\"\"\n",
    "    Generate weight with zeros.\n",
    "\n",
    "    Args:\n",
    "        shape (tuple): Shape.\n",
    "\n",
    "    Returns:\n",
    "        Tensor, var.\n",
    "    \"\"\"\n",
    "    zeros = np.zeros(shape).astype(np.float32)\n",
    "    return Tensor(zeros)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.3 衡量指标\n",
    "\n",
    "机器翻译模型中使用BLEU评价指标来衡量模型性能，同样的，我们这里也使用BLEU指标。\n",
    "\n",
    "首先加载进来的文本数据使用`Tokenizer`分词，随后将分词后数据使用sacrebleu快速计算得到`BLEU`分数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "\n",
    "def load_result_data(result_npy_addr):\n",
    "    # load the numpy to list.\n",
    "    result = np.load(result_npy_addr, allow_pickle=True)\n",
    "    return result\n",
    "\n",
    "def get_bleu_data(tokenizer: Tokenizer, result_npy_addr):\n",
    "    \"\"\"\n",
    "    Detokenizer the prediction.\n",
    "\n",
    "    Args:\n",
    "        tokenizer (Tokenizer): tokenizer operations.\n",
    "        result_npy_addr (string): Path to the predict file.\n",
    "\n",
    "    Returns:\n",
    "        List, the predict text context.\n",
    "    \"\"\"\n",
    "\n",
    "    result = load_result_data(result_npy_addr)\n",
    "    prediction_list = []\n",
    "    for _, info in enumerate(result):\n",
    "        # prediction detokenize\n",
    "        prediction = info[\"prediction\"]\n",
    "        prediction_str = tokenizer.detokenize(prediction)\n",
    "        prediction_list.append(prediction_str)\n",
    "\n",
    "    return prediction_list\n",
    "\n",
    "\n",
    "def calculate_sacrebleu(predict_path, target_path):\n",
    "    \"\"\"\n",
    "    通过sacrebleu 计算 BLEU 得分.注意：如果找不到sacrebleu这个命令，建议使用绝对路径\n",
    "\n",
    "    Args:\n",
    "        predict_path (string): Path to the predict file.\n",
    "        target_path (string): Path to the target file.\n",
    "\n",
    "    Returns:\n",
    "        Float32, bleu scores.\n",
    "    \"\"\"\n",
    "    sacrebleu_params = '--score-only -lc --tokenize intl'\n",
    "    sacrebleu = subprocess.run([f'/usr/local/python-3.7.5/bin/sacrebleu --input {predict_path} \\\n",
    "                                {target_path} {sacrebleu_params}'],\n",
    "                               stdout=subprocess.PIPE, shell=True)\n",
    "    bleu_scores = round(float(sacrebleu.stdout.strip()), 2)\n",
    "    return bleu_scores\n",
    "\n",
    "\n",
    "def bleu_calculate(tokenizer, result_npy_addr, target_addr=None):\n",
    "    \"\"\"\n",
    "    Calculate the BLEU scores.\n",
    "\n",
    "    Args:\n",
    "        tokenizer (Tokenizer): tokenizer operations.\n",
    "        result_npy_addr (string): Path to the predict file.\n",
    "        target_addr (string): Path to the target file.\n",
    "\n",
    "    Returns:\n",
    "        Float32, bleu scores.\n",
    "    \"\"\"\n",
    "\n",
    "    prediction = get_bleu_data(tokenizer, result_npy_addr)\n",
    "    print(\"predict top3:\\n\")\n",
    "    for i in range(3):\n",
    "        print(prediction[i])\n",
    "\n",
    "    eval_path = './predict.txt'\n",
    "    with open(eval_path, 'w') as eval_file:\n",
    "        lines = [line + '\\n' for line in prediction]\n",
    "        eval_file.writelines(lines)\n",
    "    reference_path = target_addr\n",
    "    bleu_scores = calculate_sacrebleu(eval_path, reference_path)\n",
    "    return bleu_scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.4 开始推理\n",
    "定义推理方法`run_eval()`方法。\n",
    "\n",
    "推理方法中的过程比较简单，加载配置参数即可开始推理。\n",
    "\n",
    "将推理输出数据从`infer()`函数中取出后，将推理输出数据放入上述实现的`bleu_calculate()`中即可得到推理结果分数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'mstype' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_32035/3332741364.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     72\u001b[0m     \u001b[0;34m'output'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"./output.npz\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     73\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m     \u001b[0;34m'compute_type'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mmstype\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat16\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     75\u001b[0m     \u001b[0;34m'dtype'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mmstype\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat16\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     76\u001b[0m }\n",
      "\u001b[0;31mNameError\u001b[0m: name 'mstype' is not defined"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "eval_config_dict = {\n",
    "    'enable_modelarts': False,\n",
    "    # Url for modelarts\n",
    "    'data_url': \"\",\n",
    "    'train_url': \"\",\n",
    "    'checkpoint_url': \"\",\n",
    "    # Path for local\n",
    "    'data_path': \"/cache/data\",\n",
    "    'output_path': \"/cache/train\",\n",
    "    'load_path': \"/cache/checkpoint_path\",\n",
    "    'device_target': \"GPU\",\n",
    "    'device_id': 0,\n",
    "    'need_modelarts_dataset_unzip': False,\n",
    "    'modelarts_dataset_unzip_name': \"\",\n",
    "\n",
    "    # ==============================================================================\n",
    "    # dataset_config\n",
    "    'random_seed': 50,\n",
    "    'epochs': 1,\n",
    "    'batch_size': 128,\n",
    "    'pre_train_dataset': \"\",\n",
    "    'fine_tune_dataset': \"\",\n",
    "    'test_dataset':  [\"./data/dataset_menu/newstest2014.en.mindrecord\"],\n",
    "    'valid_dataset': \"\",\n",
    "    'dataset_sink_mode': True,\n",
    "    'input_mask_from_dataset': False,\n",
    "\n",
    "    # model_config\n",
    "    'seq_length': 107,\n",
    "    'vocab_size': 32320,\n",
    "    'hidden_size': 1024,\n",
    "    'num_hidden_layers': 4,\n",
    "    'intermediate_size': 4096,\n",
    "    'hidden_dropout_prob': .0,\n",
    "    'attention_dropout_prob': .0,\n",
    "    'initializer_range': 0.1,\n",
    "    'label_smoothing': 0.1,\n",
    "    'beam_width': 2,\n",
    "    'length_penalty_weight': 0.6,\n",
    "    'max_decode_length': 80,\n",
    "\n",
    "    # loss_scale_config\n",
    "    'init_loss_scale': 65536,\n",
    "    'loss_scale_factor': 2,\n",
    "    'scale_window': 1000,\n",
    "\n",
    "    # learn_rate_config\n",
    "    'optimizer': \"adam\",\n",
    "    'lr': 0.002, # 2e-3\n",
    "    'lr_scheduler': \"WarmupMultiStepLR\",\n",
    "    'lr_scheduler_power': 0.5,\n",
    "    'warmup_lr_remain_steps': 0.666,\n",
    "    'warmup_lr_decay_interval': -1,\n",
    "    'decay_steps': 4,\n",
    "    'decay_start_step': -1,\n",
    "    'warmup_steps': 200,\n",
    "    'min_lr': 0.000001, # 1e-6,\n",
    "\n",
    "    # checkpoint_options\n",
    "    'existed_ckpt': \"./text_translation/ckpt_0/gnmt-8_4179.ckpt\",\n",
    "    'save_ckpt_steps': 3452,\n",
    "    'keep_ckpt_max': 6,\n",
    "    'ckpt_prefix': \"gnmt\",\n",
    "    'ckpt_path': \"text_translation\",\n",
    "\n",
    "    # eval option\n",
    "    'bpe_codes': \"./data/wmt16_de_en/bpe.32000\",\n",
    "    'test_tgt': \"./data/wmt16_de_en/newstest2014.de\",\n",
    "    'vocab': \"./data/wmt16_de_en/vocab.bpe.32000\",\n",
    "    'output': \"./output.npz\", \n",
    "\n",
    "    'compute_type': mstype.float16,\n",
    "    'dtype': mstype.float16\n",
    "}\n",
    "\n",
    "\n",
    "eval_config_keys = namedtuple('eval_config', list(eval_config_dict.keys()))\n",
    "\n",
    "eval_config = eval_config_keys._make(eval_config_dict.values())\n",
    "\n",
    "def run_eval():\n",
    "    '''run eval.'''\n",
    "    result = infer(eval_config)\n",
    "    context.set_context(\n",
    "        mode=context.GRAPH_MODE,\n",
    "        save_graphs=False,\n",
    "        device_target=eval_config.device_target,\n",
    "        device_id=eval_config.device_id,\n",
    "        reserve_class_name_in_scope=False)\n",
    "\n",
    "    with open(eval_config.output, \"wb\") as f:\n",
    "        pickle.dump(result, f, 1)\n",
    "\n",
    "    result_npy_addr = eval_config.output\n",
    "    vocab = eval_config.vocab\n",
    "    bpe_codes = eval_config.bpe_codes\n",
    "    test_tgt = eval_config.test_tgt\n",
    "    tokenizer = Tokenizer(vocab, bpe_codes, 'en', 'de')\n",
    "    scores = bleu_calculate(tokenizer, result_npy_addr, test_tgt)\n",
    "    print(f\"BLEU scores is :{scores}\")\n",
    "\n",
    "run_eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. 总结\n",
    "本文基于MindSpore框架，从数据集构建、模型构建、训练和评估等内容完整实现了GNMT V2模型，并根据机器翻译模型的特定评价指标，实现了BULE指标的计算，最后利用wmt16英德双语数据集完成模型训练，使用NEWS16英语数据集完成模型评估。通过此模型案例，能够进一步加深对机器翻译领域、GNMT模型的理解，同时由于GNMT v2中的模型结构的Encoder部分与Google后续在2018年发布的BERT模型十分相似，因此该模型也能作为理解BERT模型的一个前置案例；更重要的是，通过使用MindSpore框架实现GNMT v2模型，更深层次的理解了MindSpore的运行原理和特点，对于后续在更多模型上探索MindSpore训练的可行性打下来夯实的基础，相信在MindSpore的不断进步下，会有更多更好的模型涌现。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. 引用\n",
    " "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.5 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
